diff --git a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java index f86f8d264be1f..409781892bb53 100644 --- a/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java +++ b/java/src/org/openqa/selenium/grid/distributor/local/LocalDistributor.java @@ -102,6 +102,7 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReadWriteLock; import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -753,7 +754,7 @@ public void run() { // starting the session, we just put the request back on the queue. // This does mean, however, that under high contention, we might end // up starving a session request. - Set stereotypes = + Map stereotypes = getAvailableNodes() .stream() .filter(NodeStatus::hasCapacity) @@ -763,15 +764,15 @@ public void run() { .getSlots() .stream() .map(Slot::getStereotype) - .collect(Collectors.toSet())) + .collect(Collectors.toList())) .flatMap(Collection::stream) - .collect(Collectors.toSet()); + .collect(Collectors.groupingBy(ImmutableCapabilities::new, Collectors.counting())); if (!stereotypes.isEmpty()) { - Optional maybeRequest = sessionQueue.getNextAvailable(stereotypes); - maybeRequest.ifPresent( + List matchingRequests = sessionQueue.getNextAvailable(stereotypes); + matchingRequests.forEach( req -> sessionCreatorExecutor.execute(() -> handleNewSessionRequest(req))); - loop = maybeRequest.isPresent(); + loop = !matchingRequests.isEmpty(); } else { loop = false; } diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/GetNextMatchingRequest.java b/java/src/org/openqa/selenium/grid/sessionqueue/GetNextMatchingRequest.java index 992c5bf81d921..ea35f1b3d7eaa 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/GetNextMatchingRequest.java +++ b/java/src/org/openqa/selenium/grid/sessionqueue/GetNextMatchingRequest.java @@ -20,6 +20,7 @@ import org.openqa.selenium.Capabilities; import org.openqa.selenium.grid.data.SessionRequest; import org.openqa.selenium.internal.Require; +import org.openqa.selenium.json.Json; import org.openqa.selenium.json.TypeToken; import org.openqa.selenium.remote.http.Contents; import org.openqa.selenium.remote.http.HttpHandler; @@ -30,8 +31,9 @@ import java.io.UncheckedIOException; import java.lang.reflect.Type; -import java.util.Optional; -import java.util.Set; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import static java.util.Collections.singletonMap; import static org.openqa.selenium.remote.tracing.HttpTracing.newSpanAsChildOf; @@ -39,7 +41,8 @@ import static org.openqa.selenium.remote.tracing.Tags.HTTP_RESPONSE; class GetNextMatchingRequest implements HttpHandler { - private static final Type SET_OF_CAPABILITIES = new TypeToken>() {}.getType(); + private static final Type MAP_OF_CAPABILITIES = new TypeToken>() {}.getType(); + private static final Json JSON = new Json(); private final Tracer tracer; private final NewSessionQueue queue; @@ -53,11 +56,18 @@ public GetNextMatchingRequest(Tracer tracer, NewSessionQueue queue) { public HttpResponse execute(HttpRequest req) throws UncheckedIOException { try (Span span = newSpanAsChildOf(tracer, req, "sessionqueue.getrequest")) { HTTP_REQUEST.accept(span, req); - Set stereotypes = Contents.fromJson(req, SET_OF_CAPABILITIES); + Map stereotypesJson = Contents.fromJson(req, MAP_OF_CAPABILITIES); - Optional maybeRequest = queue.getNextAvailable(stereotypes); + Map stereotypes = new HashMap<>(); - HttpResponse response = new HttpResponse().setContent(Contents.asJson(singletonMap("value", maybeRequest.orElse(null)))); + stereotypesJson.forEach((k,v) -> { + Capabilities caps = JSON.toType(k, Capabilities.class); + stereotypes.put(caps, v); + }); + + List sessionRequestList = queue.getNextAvailable(stereotypes); + + HttpResponse response = new HttpResponse().setContent(Contents.asJson(singletonMap("value", sessionRequestList))); HTTP_RESPONSE.accept(span, response); diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/NewSessionQueue.java b/java/src/org/openqa/selenium/grid/sessionqueue/NewSessionQueue.java index 3cd6081b520bf..d5ac1ee5d8a88 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/NewSessionQueue.java +++ b/java/src/org/openqa/selenium/grid/sessionqueue/NewSessionQueue.java @@ -105,7 +105,7 @@ private RequestId requestIdFrom(Map params) { public abstract Optional remove(RequestId reqId); - public abstract Optional getNextAvailable(Set stereotypes); + public abstract List getNextAvailable(Map stereotypes); public abstract void complete(RequestId reqId, Either result); diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java b/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java index e9dbfcd5d8124..a3b41120e0c85 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java +++ b/java/src/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueue.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableSet; import org.openqa.selenium.Capabilities; +import org.openqa.selenium.ImmutableCapabilities; import org.openqa.selenium.SessionNotCreatedException; import org.openqa.selenium.concurrent.GuardedRunnable; import org.openqa.selenium.grid.config.Config; @@ -295,23 +296,33 @@ public Optional remove(RequestId reqId) { } @Override - public Optional getNextAvailable(Set stereotypes) { + public List getNextAvailable(Map stereotypes) { Require.nonNull("Stereotypes", stereotypes); Predicate matchesStereotype = - caps -> stereotypes.stream().anyMatch(stereotype -> slotMatcher.matches(stereotype, caps)); + caps -> stereotypes.entrySet() + .stream() + .filter(entry -> entry.getValue() > 0) + .anyMatch(entry -> { + boolean matches = slotMatcher.matches(entry.getKey(), caps); + if (matches) { + Long value = entry.getValue(); + entry.setValue(value - 1); + } + return matches; + }); Lock writeLock = lock.writeLock(); writeLock.lock(); try { - Optional maybeRequest = - queue.stream() - .filter(req -> req.getDesiredCapabilities().stream().anyMatch(matchesStereotype)) - .findFirst(); + List availableRequests = queue.stream() + .filter(req -> req.getDesiredCapabilities().stream().anyMatch(matchesStereotype)) + .limit(10) // TODO: Batch size should be configurable via a flag + .collect(Collectors.toList()); - maybeRequest.ifPresent(req -> this.remove(req.getRequestId())); + availableRequests.forEach(req -> this.remove(req.getRequestId())); - return maybeRequest; + return availableRequests; } finally { writeLock.unlock(); } diff --git a/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java b/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java index 832503c2606d3..81fd8650f2feb 100644 --- a/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java +++ b/java/src/org/openqa/selenium/grid/sessionqueue/remote/RemoteNewSessionQueue.java @@ -48,9 +48,10 @@ import java.lang.reflect.Type; import java.net.MalformedURLException; import java.net.URI; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; -import java.util.Set; import static org.openqa.selenium.remote.http.HttpMethod.DELETE; import static org.openqa.selenium.remote.http.HttpMethod.GET; @@ -59,6 +60,7 @@ public class RemoteNewSessionQueue extends NewSessionQueue { private static final Type QUEUE_CONTENTS_TYPE = new TypeToken>() {}.getType(); + private static final Type SESSION_REQUEST_TYPE = new TypeToken>() {}.getType(); private static final Json JSON = new Json(); private final HttpClient client; private final Filter addSecret; @@ -128,17 +130,19 @@ public Optional remove(RequestId reqId) { } @Override - public Optional getNextAvailable(Set stereotypes) { + public List getNextAvailable(Map stereotypes) { Require.nonNull("Stereotypes", stereotypes); + Map stereotypeJson = new HashMap<>(); + stereotypes.forEach((k,v) -> stereotypeJson.put(JSON.toJson(k), v)); + HttpRequest upstream = new HttpRequest(POST, "/se/grid/newsessionqueue/session/next") - .setContent(Contents.asJson(stereotypes)); + .setContent(Contents.asJson(stereotypeJson)); + HttpTracing.inject(tracer, tracer.getCurrentContext(), upstream); HttpResponse response = client.with(addSecret).execute(upstream); - SessionRequest value = Values.get(response, SessionRequest.class); - - return Optional.ofNullable(value); + return Values.get(response, SESSION_REQUEST_TYPE); } @Override diff --git a/java/test/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueueTest.java b/java/test/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueueTest.java index 98a0887691b48..e397219464b7f 100644 --- a/java/test/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueueTest.java +++ b/java/test/org/openqa/selenium/grid/sessionqueue/local/LocalNewSessionQueueTest.java @@ -51,6 +51,7 @@ import java.net.URISyntaxException; import java.time.Duration; import java.time.Instant; +import java.util.HashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -520,10 +521,57 @@ void shouldBeAbleToReturnTheNextAvailableEntryThatMatchesAStereotype(Supplier returned = queue.getNextAvailable( - Set.of(new ImmutableCapabilities("browserName", "cheese"))); + Map stereotypes = new HashMap<>(); + stereotypes.put(new ImmutableCapabilities("browserName", "cheese"), 1L); - assertThat(returned).isEqualTo(Optional.of(expected)); + List returned = queue.getNextAvailable(stereotypes); + + assertThat(returned.get(0)).isEqualTo(expected); + } + + @ParameterizedTest + @MethodSource("data") + void shouldBeAbleToReturnTheNextAvailableBatchThatMatchesStereotypes(Supplier supplier) { + setup(supplier); + + SessionRequest firstSessionRequest = new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "cheese", "se:kind", "smoked")), + Map.of(), + Map.of()); + + SessionRequest secondSessionRequest = new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "peas", "se:kind", "smoked")), + Map.of(), + Map.of()); + + SessionRequest thirdSessionRequest = new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "peas", "se:kind", "smoked")), + Map.of(), + Map.of()); + + localQueue.injectIntoQueue(firstSessionRequest); + localQueue.injectIntoQueue(secondSessionRequest); + localQueue.injectIntoQueue(thirdSessionRequest); + + Map stereotypes = new HashMap<>(); + stereotypes.put(new ImmutableCapabilities("browserName", "cheese"), 2L); + stereotypes.put(new ImmutableCapabilities("browserName", "peas"), 2L); + + List returned = queue.getNextAvailable(stereotypes); + + assertThat(returned.size()).isEqualTo(3); + assertTrue(returned.contains(firstSessionRequest)); + assertTrue(returned.contains(secondSessionRequest)); + assertTrue(returned.contains(thirdSessionRequest)); } @ParameterizedTest @@ -551,10 +599,12 @@ void shouldNotReturnANextAvailableEntryThatDoesNotMatchTheStereotypes(Supplier returned = queue.getNextAvailable( - Set.of(new ImmutableCapabilities("browserName", "cheese"))); + Map stereotypes = new HashMap<>(); + stereotypes.put(new ImmutableCapabilities("browserName", "cheese"), 1L); + + List returned = queue.getNextAvailable(stereotypes); - assertThat(returned).isEqualTo(Optional.of(expected)); + assertThat(returned.get(0)).isEqualTo(expected); } static class TestData {