diff --git a/java/src/org/openqa/selenium/grid/node/ForwardWebDriverCommand.java b/java/src/org/openqa/selenium/grid/node/ForwardWebDriverCommand.java index 458dfff372cbe..fadcfff4c80fd 100644 --- a/java/src/org/openqa/selenium/grid/node/ForwardWebDriverCommand.java +++ b/java/src/org/openqa/selenium/grid/node/ForwardWebDriverCommand.java @@ -17,7 +17,13 @@ package org.openqa.selenium.grid.node; +import static java.net.HttpURLConnection.HTTP_INTERNAL_ERROR; +import static org.openqa.selenium.remote.HttpSessionId.getSessionId; +import static org.openqa.selenium.remote.http.Contents.asJson; + +import com.google.common.collect.ImmutableMap; import org.openqa.selenium.internal.Require; +import org.openqa.selenium.remote.SessionId; import org.openqa.selenium.remote.http.HttpHandler; import org.openqa.selenium.remote.http.HttpRequest; import org.openqa.selenium.remote.http.HttpResponse; @@ -30,8 +36,22 @@ class ForwardWebDriverCommand implements HttpHandler { this.node = Require.nonNull("Node", node); } + public boolean matches(HttpRequest req) { + return getSessionId(req.getUri()) + .map(id -> node.isSessionOwner(new SessionId(id))) + .orElse(false); + } + @Override public HttpResponse execute(HttpRequest req) { - return node.executeWebDriverCommand(req); + if (matches(req)) { + return node.executeWebDriverCommand(req); + } + return new HttpResponse() + .setStatus(HTTP_INTERNAL_ERROR) + .setContent( + asJson( + ImmutableMap.of( + "error", String.format("Session not found in node %s", node.getId())))); } } diff --git a/java/src/org/openqa/selenium/grid/node/Node.java b/java/src/org/openqa/selenium/grid/node/Node.java index 5d54bb3e2981d..09fe7d02ae9f5 100644 --- a/java/src/org/openqa/selenium/grid/node/Node.java +++ b/java/src/org/openqa/selenium/grid/node/Node.java @@ -152,7 +152,7 @@ protected Node( req -> getSessionId(req.getUri()) .map(SessionId::new) - .map(this::isSessionOwner) + .map(sessionId -> this.getSession(sessionId) != null) .orElse(false)) .to(() -> new ForwardWebDriverCommand(this)) .with(spanDecorator("node.forward_command")), diff --git a/java/src/org/openqa/selenium/grid/node/local/LocalNode.java b/java/src/org/openqa/selenium/grid/node/local/LocalNode.java index 2283e8573042d..7304db8a87847 100644 --- a/java/src/org/openqa/selenium/grid/node/local/LocalNode.java +++ b/java/src/org/openqa/selenium/grid/node/local/LocalNode.java @@ -297,7 +297,13 @@ protected LocalNode( heartbeatPeriod.getSeconds(), TimeUnit.SECONDS); - Runtime.getRuntime().addShutdownHook(new Thread(this::stopAllSessions)); + Runtime.getRuntime() + .addShutdownHook( + new Thread( + () -> { + stopAllSessions(); + drain(); + })); new JMXHelper().register(this); } @@ -316,7 +322,6 @@ private void stopTimedOutSession(RemovalNotification not } // Attempt to stop the session slot.stop(); - this.sessionToDownloadsDir.invalidate(id); // Decrement pending sessions if Node is draining if (this.isDraining()) { int done = pendingSessions.decrementAndGet(); @@ -473,8 +478,6 @@ public Either newSession( sessionToDownloadsDir.put(session.getId(), uuidForSessionDownloads); currentSessions.put(session.getId(), slotToUse); - checkSessionCount(); - SessionId sessionId = session.getId(); Capabilities caps = session.getCapabilities(); SESSION_ID.accept(span, sessionId); @@ -513,6 +516,8 @@ public Either newSession( span.addEvent("Unable to create session with the driver", attributeMap); return Either.left(possibleSession.left()); } + } finally { + checkSessionCount(); } } @@ -765,6 +770,10 @@ public HttpResponse uploadFile(HttpRequest req, SessionId id) { public void stop(SessionId id) throws NoSuchSessionException { Require.nonNull("Session ID", id); + if (sessionToDownloadsDir.getIfPresent(id) != null) { + sessionToDownloadsDir.invalidate(id); + } + SessionSlot slot = currentSessions.getIfPresent(id); if (slot == null) { throw new NoSuchSessionException("Cannot find session with id: " + id); diff --git a/java/test/org/openqa/selenium/grid/node/ForwardWebDriverCommandTest.java b/java/test/org/openqa/selenium/grid/node/ForwardWebDriverCommandTest.java new file mode 100644 index 0000000000000..8f0df29251870 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/node/ForwardWebDriverCommandTest.java @@ -0,0 +1,80 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.grid.node; + +import static java.net.HttpURLConnection.HTTP_INTERNAL_ERROR; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.*; +import static org.openqa.selenium.remote.http.Contents.asJson; + +import com.google.common.collect.ImmutableMap; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.openqa.selenium.grid.data.NodeId; +import org.openqa.selenium.remote.SessionId; +import org.openqa.selenium.remote.http.HttpRequest; +import org.openqa.selenium.remote.http.HttpResponse; + +class ForwardWebDriverCommandTest { + + private Node mockNode; + private ForwardWebDriverCommand command; + + @BeforeEach + void setUp() { + mockNode = mock(Node.class); + when(mockNode.getId()).thenReturn(new NodeId(UUID.randomUUID())); + command = new ForwardWebDriverCommand(mockNode); + } + + @Test + void testExecuteWithValidSessionOwner() { + HttpRequest mockRequest = mock(HttpRequest.class); + when(mockRequest.getUri()).thenReturn("/session/1234"); + + SessionId sessionId = new SessionId("1234"); + when(mockNode.isSessionOwner(sessionId)).thenReturn(true); + + HttpResponse expectedResponse = new HttpResponse(); + when(mockNode.executeWebDriverCommand(mockRequest)).thenReturn(expectedResponse); + + HttpResponse actualResponse = command.execute(mockRequest); + assertEquals(expectedResponse, actualResponse); + } + + @Test + void testExecuteWithInvalidSessionOwner() { + HttpRequest mockRequest = mock(HttpRequest.class); + when(mockRequest.getUri()).thenReturn("/session/5678"); + + SessionId sessionId = new SessionId("5678"); + when(mockNode.isSessionOwner(sessionId)).thenReturn(false); + + HttpResponse actualResponse = command.execute(mockRequest); + HttpResponse expectResponse = + new HttpResponse() + .setStatus(HTTP_INTERNAL_ERROR) + .setContent( + asJson( + ImmutableMap.of( + "error", String.format("Session not found in node %s", mockNode.getId())))); + assertEquals(expectResponse.getStatus(), actualResponse.getStatus()); + assertEquals(expectResponse.getContentEncoding(), actualResponse.getContentEncoding()); + } +} diff --git a/java/test/org/openqa/selenium/grid/node/NodeTest.java b/java/test/org/openqa/selenium/grid/node/NodeTest.java index b1c50ed0dc65a..c2a66d2162c4b 100644 --- a/java/test/org/openqa/selenium/grid/node/NodeTest.java +++ b/java/test/org/openqa/selenium/grid/node/NodeTest.java @@ -23,6 +23,8 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.assertj.core.api.InstanceOfAssertFactories.MAP; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.openqa.selenium.json.Json.MAP_TYPE; import static org.openqa.selenium.remote.http.Contents.string; import static org.openqa.selenium.remote.http.HttpMethod.DELETE; @@ -102,7 +104,9 @@ class NodeTest { private Tracer tracer; private EventBus bus; private LocalNode local; + private LocalNode local2; private Node node; + private Node node2; private ImmutableCapabilities stereotype; private ImmutableCapabilities caps; private URI uri; @@ -150,6 +154,7 @@ public HttpResponse execute(HttpRequest req) throws UncheckedIOException { builder = builder.enableManagedDownloads(true).sessionTimeout(Duration.ofSeconds(1)); } local = builder.build(); + local2 = builder.build(); node = new RemoteNode( @@ -160,6 +165,16 @@ public HttpResponse execute(HttpRequest req) throws UncheckedIOException { registrationSecret, local.getSessionTimeout(), ImmutableSet.of(caps)); + + node2 = + new RemoteNode( + tracer, + new PassthroughHttpClient.Factory(local2), + new NodeId(UUID.randomUUID()), + uri, + registrationSecret, + local2.getSessionTimeout(), + ImmutableSet.of(caps)); } @Test @@ -371,13 +386,36 @@ void shouldOnlyRespondToWebDriverCommandsForSessionsTheNodeOwns() { assertThatEither(response).isRight(); Session session = response.right().getSession(); + Either response2 = + node2.newSession(createSessionRequest(caps)); + assertThatEither(response2).isRight(); + Session session2 = response2.right().getSession(); + + // Assert that should respond to commands for sessions Node 1 owns HttpRequest req = new HttpRequest(POST, String.format("/session/%s/url", session.getId())); assertThat(local.matches(req)).isTrue(); assertThat(node.matches(req)).isTrue(); - req = new HttpRequest(POST, String.format("/session/%s/url", UUID.randomUUID())); - assertThat(local.matches(req)).isFalse(); - assertThat(node.matches(req)).isFalse(); + // Assert that should respond to commands for sessions Node 2 owns + HttpRequest req2 = new HttpRequest(POST, String.format("/session/%s/url", session2.getId())); + assertThat(local2.matches(req2)).isTrue(); + assertThat(node2.matches(req2)).isTrue(); + + // Assert that should not respond to commands for sessions Node 1 does not own + NoSuchSessionException exception = + assertThrows(NoSuchSessionException.class, () -> node.execute(req2)); + assertTrue( + exception + .getMessage() + .startsWith(String.format("Cannot find session with id: %s", session2.getId()))); + + // Assert that should not respond to commands for sessions Node 2 does not own + NoSuchSessionException exception2 = + assertThrows(NoSuchSessionException.class, () -> node2.execute(req)); + assertTrue( + exception2 + .getMessage() + .startsWith(String.format("Cannot find session with id: %s", session.getId()))); } @Test diff --git a/java/test/org/openqa/selenium/grid/router/StressTest.java b/java/test/org/openqa/selenium/grid/router/StressTest.java index a2c432229297f..3be76395e4f01 100644 --- a/java/test/org/openqa/selenium/grid/router/StressTest.java +++ b/java/test/org/openqa/selenium/grid/router/StressTest.java @@ -20,6 +20,8 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.TimeUnit.MINUTES; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.StringReader; import java.util.LinkedList; @@ -33,7 +35,10 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.openqa.selenium.By; +import org.openqa.selenium.MutableCapabilities; +import org.openqa.selenium.NoSuchSessionException; import org.openqa.selenium.WebDriver; +import org.openqa.selenium.WebDriverException; import org.openqa.selenium.grid.config.MapConfig; import org.openqa.selenium.grid.config.MemoizedConfig; import org.openqa.selenium.grid.config.TomlConfig; @@ -65,7 +70,14 @@ public void setupServers() { DeploymentTypes.DISTRIBUTED.start( browser.getCapabilities(), new TomlConfig( - new StringReader("[node]\n" + "driver-implementation = " + browser.displayName()))); + new StringReader( + "[node]\n" + + "driver-implementation = " + + browser.displayName() + + "\n" + + "session-timeout = 11" + + "\n" + + "enable-managed-downloads = true"))); tearDowns.add(deployment); server = deployment.getServer(); @@ -106,7 +118,10 @@ void multipleSimultaneousSessions() throws Exception { try { WebDriver driver = RemoteWebDriver.builder() - .oneOf(browser.getCapabilities()) + .oneOf( + browser + .getCapabilities() + .merge(new MutableCapabilities(Map.of("se:downloadsEnabled", true)))) .address(server.getUrl()) .build(); @@ -124,4 +139,44 @@ void multipleSimultaneousSessions() throws Exception { CompletableFuture.allOf(futures).get(4, MINUTES); } + + @Test + void multipleSimultaneousSessionsTimedOut() throws Exception { + assertThat(server.isStarted()).isTrue(); + + CompletableFuture[] futures = new CompletableFuture[10]; + for (int i = 0; i < futures.length; i++) { + CompletableFuture future = new CompletableFuture<>(); + futures[i] = future; + + executor.submit( + () -> { + try { + WebDriver driver = + RemoteWebDriver.builder() + .oneOf(browser.getCapabilities()) + .address(server.getUrl()) + .build(); + driver.get(appServer.getUrl().toString()); + Thread.sleep(11000); + NoSuchSessionException exception = + assertThrows(NoSuchSessionException.class, driver::getTitle); + assertTrue(exception.getMessage().startsWith("Cannot find session with id:")); + WebDriverException webDriverException = + assertThrows( + WebDriverException.class, + () -> ((RemoteWebDriver) driver).getDownloadableFiles()); + assertTrue( + webDriverException + .getMessage() + .startsWith("Cannot find downloads file system for session id:")); + future.complete(true); + } catch (Exception e) { + future.completeExceptionally(e); + } + }); + } + + CompletableFuture.allOf(futures).get(5, MINUTES); + } }