diff --git a/java/src/org/openqa/selenium/remote/Augmenter.java b/java/src/org/openqa/selenium/remote/Augmenter.java index cda080b1c756a..d64f25a0ff055 100644 --- a/java/src/org/openqa/selenium/remote/Augmenter.java +++ b/java/src/org/openqa/selenium/remote/Augmenter.java @@ -50,6 +50,7 @@ import org.openqa.selenium.internal.Require; import org.openqa.selenium.logging.HasLogEvents; import org.openqa.selenium.remote.html5.AddWebStorage; +import org.openqa.selenium.support.decorators.Decorated; /** * Enhance the interfaces implemented by an instance of the {@link org.openqa.selenium.WebDriver} @@ -251,6 +252,10 @@ private RemoteWebDriver extractRemoteWebDriver(WebDriver driver) { return (RemoteWebDriver) driver; } + if (driver instanceof Decorated) { + return extractRemoteWebDriver((WebDriver) ((Decorated) driver).getOriginal()); + } + if (driver instanceof WrapsDriver) { return extractRemoteWebDriver(((WrapsDriver) driver).getWrappedDriver()); } diff --git a/java/src/org/openqa/selenium/remote/BUILD.bazel b/java/src/org/openqa/selenium/remote/BUILD.bazel index e93f0b4643667..c231530f0b839 100644 --- a/java/src/org/openqa/selenium/remote/BUILD.bazel +++ b/java/src/org/openqa/selenium/remote/BUILD.bazel @@ -60,6 +60,7 @@ java_library( "//java/src/org/openqa/selenium/remote/http/netty", "//java/src/org/openqa/selenium/remote/tracing", "//java/src/org/openqa/selenium/remote/tracing/opentelemetry", + "//java/src/org/openqa/selenium/support/decorators", artifact("com.google.guava:guava"), artifact("net.bytebuddy:byte-buddy"), ], diff --git a/java/src/org/openqa/selenium/support/BUILD.bazel b/java/src/org/openqa/selenium/support/BUILD.bazel index a0477e2a6abfc..c378ca0b5c033 100644 --- a/java/src/org/openqa/selenium/support/BUILD.bazel +++ b/java/src/org/openqa/selenium/support/BUILD.bazel @@ -16,6 +16,7 @@ java_export( visibility = ["//visibility:public"], exports = [ ":page-factory", + "//java/src/org/openqa/selenium/support/decorators", "//java/src/org/openqa/selenium/support/events", "//java/src/org/openqa/selenium/support/locators", "//java/src/org/openqa/selenium/support/ui:clock", diff --git a/java/src/org/openqa/selenium/support/decorators/BUILD.bazel b/java/src/org/openqa/selenium/support/decorators/BUILD.bazel index 1166e43b9c23f..5ad1eb0cd9fd5 100644 --- a/java/src/org/openqa/selenium/support/decorators/BUILD.bazel +++ b/java/src/org/openqa/selenium/support/decorators/BUILD.bazel @@ -5,6 +5,7 @@ java_library( name = "decorators", srcs = glob(["*.java"]), visibility = [ + "//java/src/org/openqa/selenium/remote:__subpackages__", "//java/src/org/openqa/selenium/support:__subpackages__", "//java/test/org/openqa/selenium/support/decorators:__pkg__", ], diff --git a/java/test/org/openqa/selenium/remote/AugmenterTest.java b/java/test/org/openqa/selenium/remote/AugmenterTest.java index 2928f547d4afc..82a191705cd47 100644 --- a/java/test/org/openqa/selenium/remote/AugmenterTest.java +++ b/java/test/org/openqa/selenium/remote/AugmenterTest.java @@ -44,6 +44,8 @@ import org.openqa.selenium.internal.Require; import org.openqa.selenium.support.decorators.Decorated; import org.openqa.selenium.support.decorators.WebDriverDecorator; +import org.openqa.selenium.support.events.EventFiringDecorator; +import org.openqa.selenium.support.events.WebDriverListener; @Tag("UnitTests") class AugmenterTest { @@ -240,6 +242,40 @@ void shouldDecorateAugmentedWebDriver() { assertThat(number).isEqualTo(42); } + @Test + void shouldAugmentDecoratedWebDriver() { + final Capabilities caps = + new ImmutableCapabilities( + "magic.numbers", true, + "numbers", true); + WebDriver driver = new RemoteWebDriver(new StubExecutor(caps), caps); + WebDriver eventFiringDecorate = + new EventFiringDecorator<>( + new WebDriverListener() { + @Override + public void beforeAnyCall(Object target, Method method, Object[] args) { + System.out.println("Bazinga!"); + } + }) + .decorate(driver); + + WebDriver modifyTitleDecorate = + new ModifyTitleWebDriverDecorator().decorate(eventFiringDecorate); + + WebDriver augmented = + getAugmenter() + .addDriverAugmentation("magic.numbers", HasMagicNumbers.class, (c, exe) -> () -> 42) + .augment(modifyTitleDecorate); + + assertThat(modifyTitleDecorate).isNotSameAs(driver); + + assertThat(((HasMagicNumbers) augmented).getMagicNumber()).isEqualTo(42); + assertThat(augmented.getTitle()).isEqualTo("title"); + + assertThat(augmented).isNotSameAs(modifyTitleDecorate); + assertThat(augmented).isInstanceOf(Decorated.class); + } + private static class ByMagic extends By { private final String magicWord;