Skip to content

Commit

Permalink
[py] implement file downloads (#13023)
Browse files Browse the repository at this point in the history
* [py] download file do not return a string

* [py] require enabling downloads
  • Loading branch information
titusfortner committed Nov 1, 2023
1 parent 96f13f8 commit 605fccd
Show file tree
Hide file tree
Showing 6 changed files with 139 additions and 6 deletions.
15 changes: 14 additions & 1 deletion py/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def fin():
options = get_options(driver_class, request.config)
if driver_class == "Remote":
options = get_options("Firefox", request.config) or webdriver.FirefoxOptions()
options.set_capability("moz:firefoxOptions", {})
options.enable_downloads = True
if driver_class == "WebKitGTK":
options = get_options(driver_class, request.config)
if driver_class == "Edge":
Expand Down Expand Up @@ -246,7 +248,18 @@ def wait_for_server(url, timeout):
except Exception:
print("Starting the Selenium server")
process = subprocess.Popen(
["java", "-jar", _path, "standalone", "--port", "4444", "--selenium-manager", "true"]
[
"java",
"-jar",
_path,
"standalone",
"--port",
"4444",
"--selenium-manager",
"true",
"--enable-managed-downloads",
"true",
]
)
print(f"Selenium server running as process: {process.pid}")
assert wait_for_server(url, 10), f"Timed out waiting for Selenium server at {url}"
Expand Down
24 changes: 23 additions & 1 deletion py/selenium/webdriver/common/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, name):
self.name = name

def __get__(self, obj, cls):
if self.name in ("acceptInsecureCerts", "strictFileInteractability", "setWindowRect"):
if self.name in ("acceptInsecureCerts", "strictFileInteractability", "setWindowRect", "se:downloadsEnabled"):
return obj._caps.get(self.name, False)
return obj._caps.get(self.name)

Expand Down Expand Up @@ -322,6 +322,28 @@ class BaseOptions(metaclass=ABCMeta):
- `None`
"""

enable_downloads = _BaseOptionsDescriptor("se:downloadsEnabled")
"""Gets and Sets whether session can download files.
Usage
-----
- Get
- `self.enable_downloads`
- Set
- `self.enable_downloads` = `value`
Parameters
----------
`value`: `bool`
Returns
-------
- Get
- `bool`
- Set
- `None`
"""

def __init__(self) -> None:
super().__init__()
self._caps = self.default_capabilities
Expand Down
9 changes: 6 additions & 3 deletions py/selenium/webdriver/remote/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ class Command:
https://w3c.github.io/webdriver/
"""

# Keep in sync with org.openqa.selenium.remote.DriverCommand

NEW_SESSION: str = "newSession"
DELETE_SESSION: str = "deleteSession"
NEW_WINDOW: str = "newWindow"
Expand All @@ -49,7 +47,6 @@ class Command:
CLEAR_ELEMENT: str = "clearElement"
CLICK_ELEMENT: str = "clickElement"
SEND_KEYS_TO_ELEMENT: str = "sendKeysToElement"
UPLOAD_FILE: str = "uploadFile"
W3C_GET_CURRENT_WINDOW_HANDLE: str = "w3cGetCurrentWindowHandle"
W3C_GET_WINDOW_HANDLES: str = "w3cGetWindowHandles"
SET_WINDOW_RECT: str = "setWindowRect"
Expand Down Expand Up @@ -119,3 +116,9 @@ class Command:
REMOVE_CREDENTIAL: str = "removeCredential"
REMOVE_ALL_CREDENTIALS: str = "removeAllCredentials"
SET_USER_VERIFIED: str = "setUserVerified"

# Remote File Management
UPLOAD_FILE: str = "uploadFile"
GET_DOWNLOADABLE_FILES: str = "getDownloadableFiles"
DOWNLOAD_FILE: str = "downloadFile"
DELETE_DOWNLOADABLE_FILES: str = "deleteDownloadableFiles"
5 changes: 4 additions & 1 deletion py/selenium/webdriver/remote/remote_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@
Command.CLEAR_ELEMENT: ("POST", "/session/$sessionId/element/$id/clear"),
Command.GET_ELEMENT_TEXT: ("GET", "/session/$sessionId/element/$id/text"),
Command.SEND_KEYS_TO_ELEMENT: ("POST", "/session/$sessionId/element/$id/value"),
Command.UPLOAD_FILE: ("POST", "/session/$sessionId/se/file"),
Command.GET_ELEMENT_TAG_NAME: ("GET", "/session/$sessionId/element/$id/name"),
Command.IS_ELEMENT_SELECTED: ("GET", "/session/$sessionId/element/$id/selected"),
Command.IS_ELEMENT_ENABLED: ("GET", "/session/$sessionId/element/$id/enabled"),
Expand Down Expand Up @@ -122,6 +121,10 @@
"/session/$sessionId/webauthn/authenticator/$authenticatorId/credentials",
),
Command.SET_USER_VERIFIED: ("POST", "/session/$sessionId/webauthn/authenticator/$authenticatorId/uv"),
Command.UPLOAD_FILE: ("POST", "/session/$sessionId/se/file"),
Command.GET_DOWNLOADABLE_FILES: ("GET", "/session/$sessionId/se/files"),
Command.DOWNLOAD_FILE: ("POST", "/session/$sessionId/se/files"),
Command.DELETE_DOWNLOADABLE_FILES: ("DELETE", "/session/$sessionId/se/files"),
}


Expand Down
36 changes: 36 additions & 0 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""The WebDriver implementation."""
import base64
import contextlib
import copy
import os
import pkgutil
import types
import typing
Expand Down Expand Up @@ -1132,3 +1134,37 @@ def set_user_verified(self, verified: bool) -> None:
verified: True if the authenticator will pass user verification, False otherwise.
"""
self.execute(Command.SET_USER_VERIFIED, {"authenticatorId": self._authenticator_id, "isUserVerified": verified})

def get_downloadable_files(self) -> dict:
"""Retrieves the downloadable files as a map of file names and their
corresponding URLs."""
if "se:downloadsEnabled" not in self.capabilities:
raise WebDriverException("You must enable downloads in order to work with downloadable files.")

return self.execute(Command.GET_DOWNLOADABLE_FILES)["value"]["names"]

def download_file(self, file_name: str, target_directory: str) -> None:
"""Downloads a file with the specified file name to the target
directory.
file_name: The name of the file to download.
target_directory: The path to the directory to save the downloaded file.
"""
if "se:downloadsEnabled" not in self.capabilities:
raise WebDriverException("You must enable downloads in order to work with downloadable files.")

if not os.path.exists(target_directory):
os.makedirs(target_directory)

contents = self.execute(Command.DOWNLOAD_FILE, {"name": file_name})["value"]["contents"]

target_file = os.path.join(target_directory, file_name)
with open(target_file, "wb") as file:
file.write(base64.b64decode(contents))

def delete_downloadable_files(self) -> None:
"""Deletes all downloadable files."""
if "se:downloadsEnabled" not in self.capabilities:
raise WebDriverException("You must enable downloads in order to work with downloadable files.")

self.execute(Command.DELETE_DOWNLOADABLE_FILES)
56 changes: 56 additions & 0 deletions py/test/selenium/webdriver/remote/remote_downloads_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import tempfile

from selenium.webdriver.common.by import By
from selenium.webdriver.support.wait import WebDriverWait


def test_get_downloadable_files(driver, pages):
_browser_downloads(driver, pages)

file_names = driver.get_downloadable_files()

assert "file_1.txt" in file_names
assert "file_2.jpg" in file_names


def test_download_file(driver, pages):
_browser_downloads(driver, pages)

file_name = driver.get_downloadable_files()[0]
with tempfile.TemporaryDirectory() as target_directory:
driver.download_file(file_name, target_directory)

target_file = os.path.join(target_directory, file_name)
with open(target_file, "r") as file:
assert "Hello, World!" in file.read()


def test_delete_downloadable_files(driver, pages):
_browser_downloads(driver, pages)

driver.delete_downloadable_files()
assert not driver.get_downloadable_files()


def _browser_downloads(driver, pages):
pages.load("downloads/download.html")
driver.find_element(By.ID, "file-1").click()
driver.find_element(By.ID, "file-2").click()
WebDriverWait(driver, 3).until(lambda d: "file_2.jpg" in d.get_downloadable_files())

0 comments on commit 605fccd

Please sign in to comment.