diff --git a/pyproject.toml b/pyproject.toml
index 56a23c11..e8810515 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,13 +27,12 @@ dependencies = [
     "pooch>=1.7.0",  # for scipy.datasets
     "docutils>=0.15",
-    "requests>=2.28.2",
+    "httpx>=0.23.0",
     "beautifulsoup4>=4.4",
     "lxml>=4.5.0",
     "pyyaml>5.1",
     "more-itertools>=9.0",
     "tqdm>=4.66.1",
-    "requests-cache>1.0",
     "synphot>=1.1.0",
     "skycalc_ipy>=0.2.0",
diff --git a/scopesim/commands/user_commands.py b/scopesim/commands/user_commands.py
index 9815dbbe..2d40336f 100644
--- a/scopesim/commands/user_commands.py
+++ b/scopesim/commands/user_commands.py
@@ -5,7 +5,7 @@
 import numpy as np
 import yaml
 
-import requests
+import httpx
 
 from .. import rc
 from ..utils import find_file, top_level_catch
@@ -290,7 +290,7 @@ def check_for_updates(package_name):
     front_matter = rc.__currsys__["!SIM.file.server_base_url"]
     back_matter = f"api.php?package_name={package_name}"
     try:
-        response = requests.get(url=front_matter+back_matter).json()
+        response = httpx.get(url=front_matter+back_matter).json()
     except:
         print(f"Offline. Cannot check for updates for {package_name}")
     return response
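
# --- Editor's note (not part of the patch) -----------------------------------
# For one-shot GETs like the hunk above, httpx.get is a near drop-in
# replacement for requests.get. One difference worth knowing: httpx applies a
# default timeout of 5 seconds, where requests would wait indefinitely, so the
# bare `except:` above now also catches httpx.TimeoutException. A minimal
# sketch of the same pattern with explicit error handling (names below are
# hypothetical, not part of the codebase):
import httpx

def check_for_updates_sketch(base_url: str, package_name: str) -> dict:
    """Query the package server, falling back to an empty dict when offline."""
    try:
        response = httpx.get(f"{base_url}api.php?package_name={package_name}",
                             timeout=5.0)
        response.raise_for_status()  # surface 4xx/5xx instead of odd JSON errors
        return response.json()
    except httpx.HTTPError:
        print(f"Offline. Cannot check for updates for {package_name}")
        return {}
# ------------------------------------------------------------------------------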
p_name).strip(".zip"), - } - pkgs_dict[p_name] = p_dict + with create_client(get_base_url()) as client: + folders = list(dict(crawl_server_dirs(client)).keys()) + pkgs_dict = {} + for dir_name in folders: + p_list = [_parse_package_version(package) for package + in get_server_folder_contents(client, dir_name)] + grouped = dict(group_package_versions(p_list)) + for p_name in grouped: + p_dict = { + "latest": _unparse_raw_version(get_latest(grouped[p_name]), + p_name).strip(".zip"), + "path": dir_name.strip("/"), + "stable": _unparse_raw_version(get_stable(grouped[p_name]), + p_name).strip(".zip"), + } + pkgs_dict[p_name] = p_dict return pkgs_dict -def get_server_folder_contents(dir_name: str, - unique_str: str = ".zip$") -> Iterator[str]: - url = rc.__config__["!SIM.file.server_base_url"] + dir_name - - retry_strategy = Retry(total=2, - status_forcelist=HTTP_RETRY_CODES, - allowed_methods=["GET"]) - adapter = HTTPAdapter(max_retries=retry_strategy) - - try: - with requests.Session() as session: - session.mount("https://", adapter) - result = session.get(url).content - except (requests.exceptions.ConnectionError, - requests.exceptions.RetryError) as error: - logging.error(error) - raise ServerError("Cannot connect to server. " - f"Attempted URL was: {url}.") from error - except Exception as error: - logging.error(("Unhandled exception occured while accessing server." - "Attempted URL was: %s."), url) - logging.error(error) - raise error - - soup = bs4.BeautifulSoup(result, features="lxml") - hrefs = soup.find_all("a", href=True, string=re.compile(unique_str)) - pkgs = (href.string for href in hrefs) - - return pkgs - - def _get_package_name(package: str) -> str: return package.split(".", maxsplit=1)[0] def _parse_raw_version(raw_version: str) -> str: - """Catch initial package version which has no date info + """Catch initial package version which has no date info. Set initial package version to basically "minus infinity". """ @@ -109,8 +71,8 @@ def _parse_raw_version(raw_version: str) -> str: def _unparse_raw_version(raw_version: str, package_name: str) -> str: - """Turn version string back into full zip folder name - + """Turn version string back into full zip folder name. + If initial version was set with `_parse_raw_version`, revert that. 
""" if raw_version == str(date(1, 1, 1)): @@ -176,7 +138,7 @@ def get_all_latest(version_groups: _GrpVerType) -> Iterator[Tuple[str, str]]: def group_package_versions(all_packages: Iterable[Tuple[str, str]]) -> _GrpItrType: - """Group different versions of packages by package name""" + """Group different versions of packages by package name.""" version_groups = groupby_transform(sorted(all_packages), keyfunc=first, valuefunc=last, @@ -184,12 +146,17 @@ def group_package_versions(all_packages: Iterable[Tuple[str, str]]) -> _GrpItrTy return version_groups -def crawl_server_dirs() -> Iterator[Tuple[str, Set[str]]]: - """Search all folders on server for .zip files""" - for dir_name in get_server_folder_contents("", "/"): +def crawl_server_dirs(client=None) -> Iterator[Tuple[str, Set[str]]]: + """Search all folders on server for .zip files.""" + if client is None: + with create_client(get_base_url()) as client: + yield from crawl_server_dirs(client) + return + + for dir_name in get_server_folder_contents(client, "", "/"): logging.info("Searching folder '%s'", dir_name) try: - p_dir = get_server_folder_package_names(dir_name) + p_dir = get_server_folder_package_names(client, dir_name) except ValueError as err: logging.info(err) continue @@ -197,30 +164,38 @@ def crawl_server_dirs() -> Iterator[Tuple[str, Set[str]]]: yield dir_name, p_dir -def get_all_package_versions() -> Dict[str, List[str]]: - """Gather all versions for all packages present in any folder on server""" +def get_all_package_versions(client=None) -> Dict[str, List[str]]: + """Gather all versions for all packages present in any folder on server.""" + if client is None: + with create_client(get_base_url()) as client: + return get_all_package_versions(client) + grouped = {} - folders = list(dict(crawl_server_dirs()).keys()) + folders = list(dict(crawl_server_dirs(client)).keys()) for dir_name in folders: p_list = [_parse_package_version(package) for package - in get_server_folder_contents(dir_name)] + in get_server_folder_contents(client, dir_name)] grouped.update(group_package_versions(p_list)) return grouped -def get_package_folders() -> Dict[str, str]: - folder_dict = {pkg: path.strip("/") - for path, pkgs in dict(crawl_server_dirs()).items() - for pkg in pkgs} - return folder_dict +def get_package_folders(client) -> Dict[str, str]: + """Map package names to server locations.""" + folders_dict = {pkg: path.strip("/") + for path, pkgs in dict(crawl_server_dirs(client)).items() + for pkg in pkgs} + return folders_dict -def get_server_folder_package_names(dir_name: str) -> Set[str]: +def get_server_folder_package_names(client, dir_name: str) -> Set[str]: """ Retrieve all unique package names present on server in `dir_name` folder. Parameters ---------- + client : httpx.Client + Pre-existing httpx Client context manager. + dir_name : str Name of the folder on the server. @@ -236,7 +211,7 @@ def get_server_folder_package_names(dir_name: str) -> Set[str]: """ package_names = {package.split(".", maxsplit=1)[0] for package - in get_server_folder_contents(dir_name)} + in get_server_folder_contents(client, dir_name)} if not package_names: raise ValueError(f"No packages found in directory \"{dir_name}\".") @@ -264,15 +239,12 @@ def get_all_packages_on_server() -> Iterator[Tuple[str, set]]: Key-value pairs of folder and corresponding package names. """ - # TODO: this basically does the same as the crawl function... 
- for dir_name in ("locations", "telescopes", "instruments"): - package_names = get_server_folder_package_names(dir_name) - yield dir_name, package_names + yield from crawl_server_dirs() def list_packages(pkg_name: Optional[str] = None) -> List[str]: """ - List all packages, or all variants of a single package + List all packages, or all variants of a single package. Parameters ---------- @@ -303,7 +275,7 @@ def list_packages(pkg_name: Optional[str] = None) -> List[str]: all_stable = list(dict(get_all_stable(all_grouped)).keys()) return all_stable - if not pkg_name in all_grouped: + if pkg_name not in all_grouped: raise ValueError(f"Package name {pkg_name} not found on server.") p_versions = [_unparse_raw_version(version, pkg_name) @@ -327,18 +299,18 @@ def _get_zipname(pkg_name: str, release: str, all_versions) -> str: return _unparse_raw_version(zip_name, pkg_name) -def _download_single_package(pkg_name: str, release: str, all_versions, - folder_dict: Path, base_url: str, save_dir: Path, - padlen: int, from_cache: bool) -> Path: +def _download_single_package(client, pkg_name: str, release: str, all_versions, + folders_dict: Path, save_dir: Path, + padlen: int) -> Path: if pkg_name not in all_versions: maybe = "" - for key in folder_dict: + for key in folders_dict: if pkg_name in key or key in pkg_name: maybe = f"\nDid you mean '{key}' instead of '{pkg_name}'?" - raise PkgNotFoundError(f"Unable to find {release} release for " - f"package '{pkg_name}' on server {base_url}." - + maybe) + raise PkgNotFoundError( + f"Unable to find {release} release for package '{pkg_name}' on " + f"server {client.base_url!s}.{maybe}") if save_dir is None: save_dir = rc.__config__["!SIM.file.local_packages_path"] @@ -353,43 +325,20 @@ def _download_single_package(pkg_name: str, release: str, all_versions, return save_dir.absolute() zip_name = _get_zipname(pkg_name, release, all_versions) - pkg_url = f"{base_url}{folder_dict[pkg_name]}/{zip_name}" - - try: - if from_cache is None: - from_cache = rc.__config__["!SIM.file.use_cached_downloads"] - - response = initiate_download(pkg_url, from_cache, "test_cache") - save_path = save_dir / f"{pkg_name}.zip" - handle_download(response, save_path, pkg_name, padlen) - handle_unzipping(save_path, save_dir, pkg_name, padlen) - - except HTTPError3 as error: - logging.error(error) - msg = f"Unable to find file: {pkg_url + pkg_name}" - raise ValueError(msg) from error - except HTTPError as error: - logging.error("urllib (not urllib3) error was raised, this should " - "not happen anymore!") - logging.error(error) - except requests.exceptions.ConnectionError as error: - logging.error(error) - raise ServerError("Cannot connect to server.") from error - except Exception as error: - logging.error(("Unhandled exception occured while accessing server." - "Attempted URL was: %s."), base_url) - logging.error(error) - raise error + pkg_url = f"{folders_dict[pkg_name]}/{zip_name}" + + save_path = save_dir / f"{pkg_name}.zip" + handle_download(client, pkg_url, save_path, pkg_name, padlen) + handle_unzipping(save_path, save_dir, pkg_name, padlen) return save_path.absolute() def download_packages(pkg_names: Union[Iterable[str], str], release: str = "stable", - save_dir: Optional[str] = None, - from_cache: Optional[bool] = None) -> List[Path]: + save_dir: Optional[str] = None) -> List[Path]: """ - Download one or more packages to the local disk + Download one or more packages to the local disk. 1. Download stable, dev 2. 
Download specific version @@ -443,31 +392,32 @@ def download_packages(pkg_names: Union[Iterable[str], str], download_packages("ELT", release="github@dev_master") """ - base_url = rc.__config__["!SIM.file.server_base_url"] - + base_url = get_base_url() print("Gathering information from server ...") - - all_versions = get_all_package_versions() - folder_dict = get_package_folders() - - print("Connection successful, starting download ...") - - if isinstance(pkg_names, str): - pkg_names = [pkg_names] - - padlen = len(max(pkg_names, key=len)) - save_paths = [] - for pkg_name in pkg_names: - try: - pkg_path = _download_single_package(pkg_name, release, all_versions, - folder_dict, base_url, save_dir, - padlen, from_cache) - except PkgNotFoundError as error: - logging.error("\n") # needed until tqdm redirect is implemented - logging.error(error) - logging.error("Skipping download of package '%s'", pkg_name) - continue - save_paths.append(pkg_path) + logging.info("Accessing %s", base_url) + + with create_client(base_url) as client: + all_versions = get_all_package_versions(client) + folders_dict = get_package_folders(client) + + print("Connection successful, starting download ...") + + if isinstance(pkg_names, str): + pkg_names = [pkg_names] + + padlen = max(len(name) for name in pkg_names) + save_paths = [] + for pkg_name in pkg_names: + try: + pkg_path = _download_single_package( + client, pkg_name, release, all_versions, folders_dict, + save_dir, padlen) + except PkgNotFoundError as error: + logging.error("\n") # needed until tqdm redirect implemented + logging.error(error) + logging.error("Skipping download of package '%s'", pkg_name) + continue + save_paths.append(pkg_path) return save_paths @@ -513,5 +463,4 @@ def download_package(pkg_path, save_dir=None, url=None, from_cache=None): pkg_path = [pkg_path] pkg_names = [pkg.replace(".zip", "").split("/")[-1] for pkg in pkg_path] - return download_packages(pkg_names, release="stable", save_dir=save_dir, - from_cache=from_cache) + return download_packages(pkg_names, release="stable", save_dir=save_dir) diff --git a/scopesim/server/download_utils.py b/scopesim/server/download_utils.py index 61738ba0..15b47aab 100644 --- a/scopesim/server/download_utils.py +++ b/scopesim/server/download_utils.py @@ -1,23 +1,27 @@ # -*- coding: utf-8 -*- -""" -Used only by the `database` and `github_utils` submodules. -""" +"""Used only by the `database` and `github_utils` submodules.""" + +import re +import logging + +# Python 3.8 doesn't yet know these things....... 
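
# --- Editor's note (not part of the patch) -----------------------------------
# The recurring idiom in database.py above is an optional `client` argument:
# helpers share one connection pool when a client is passed in, yet still work
# standalone by opening (and closing) their own. For generators this needs the
# `yield from` re-entry plus an early return, as in crawl_server_dirs. A
# minimal sketch with placeholder names and URL:
from typing import Iterator, Optional
import httpx

def crawl_sketch(client: Optional[httpx.Client] = None) -> Iterator[str]:
    """Open a client on demand, otherwise reuse the caller's."""
    if client is None:
        with httpx.Client(base_url="https://example.com") as client:
            yield from crawl_sketch(client)  # re-enter with a live client
        return
    response = client.get("subdir/")  # relative URL, resolved against base_url
    yield from response.text.splitlines()  # stand-in for the real HTML parsing
# ------------------------------------------------------------------------------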
diff --git a/scopesim/server/download_utils.py b/scopesim/server/download_utils.py
index 61738ba0..15b47aab 100644
--- a/scopesim/server/download_utils.py
+++ b/scopesim/server/download_utils.py
@@ -1,23 +1,27 @@
 # -*- coding: utf-8 -*-
-"""
-Used only by the `database` and `github_utils` submodules.
-"""
+"""Used only by the `database` and `github_utils` submodules."""
+
+import re
+import logging
+
+# Python 3.8 doesn't yet know these things.......
+# from collections.abc import Iterator, Iterable, Mapping
+from typing import Iterator
 
 from zipfile import ZipFile
 from pathlib import Path
 from shutil import get_terminal_size
 
-import requests
-from requests.packages.urllib3.util.retry import Retry
-from requests.adapters import HTTPAdapter
-from requests_cache import CachedSession
+import httpx
+import bs4
 
 from tqdm import tqdm
 # from tqdm.contrib.logging import logging_redirect_tqdm
 # put with logging_redirect_tqdm(loggers=all_loggers): around tqdm
 
 
-HTTP_RETRY_CODES = [403, 404, 429, 500, 501, 502, 503]
+class ServerError(Exception):
+    """Some error with the server or connection to the server."""
 
 
 def _make_tqdm_kwargs(desc: str = ""):
@@ -31,42 +35,90 @@ def _make_tqdm_kwargs(desc: str = ""):
     return tqdm_kwargs
 
 
-def _create_session(cached: bool = False, cache_name: str = ""):
+def create_client(base_url, cached: bool = False, cache_name: str = ""):
+    """Create httpx Client instance; should support caching at some point."""
     if cached:
-        return CachedSession(cache_name)
-    return requests.Session()
-
-
-def initiate_download(pkg_url: str,
-                      cached: bool = False, cache_name: str = "",
-                      total: int = 5, backoff_factor: int = 2):
-    retry_strategy = Retry(total=total, backoff_factor=backoff_factor,
-                           status_forcelist=HTTP_RETRY_CODES,
-                           allowed_methods=["GET"])
-    adapter = HTTPAdapter(max_retries=retry_strategy)
-    with _create_session(cached, cache_name) as session:
-        session.mount("https://", adapter)
-        response = session.get(pkg_url, stream=True)
-    return response
+        raise NotImplementedError("Caching not yet implemented with httpx.")
+    transport = httpx.HTTPTransport(retries=5)
+    client = httpx.Client(base_url=base_url, timeout=2, transport=transport)
+    return client
 
 
-def handle_download(response, save_path: Path, pkg_name: str,
+def handle_download(client, pkg_url: str,
+                    save_path: Path, pkg_name: str,
                     padlen: int, chunk_size: int = 128,
                     disable_bar=False) -> None:
+    """Perform a streamed download and write the content to disk."""
     tqdm_kwargs = _make_tqdm_kwargs(f"Downloading {pkg_name:<{padlen}}")
-    total = int(response.headers.get("content-length", 0))
-    # Turn this into non-nested double with block in Python 3.9 or 10 (?)
-    with save_path.open("wb") as file_outer:
-        with tqdm.wrapattr(file_outer, "write", miniters=1, total=total,
-                           **tqdm_kwargs, disable=disable_bar) as file_inner:
-            for chunk in response.iter_content(chunk_size=chunk_size):
-                file_inner.write(chunk)
+
+    stream = send_get(client, pkg_url, stream=True)
+
+    try:
+        with stream as response:
+            response.raise_for_status()
+            total = int(response.headers.get("Content-Length", 0))
+
+            # Turn this into non-nested double with block in Python 3.9 or 10
+            with save_path.open("wb") as file_outer:
+                with tqdm.wrapattr(file_outer, "write", miniters=1,
+                                   total=total, **tqdm_kwargs,
+                                   disable=disable_bar) as file_inner:
+                    for chunk in response.iter_bytes(chunk_size=chunk_size):
+                        file_inner.write(chunk)
+
+    except httpx.HTTPStatusError as err:
+        logging.error("Error response %s while requesting %s.",
+                      err.response.status_code, err.request.url)
+        raise ServerError("Cannot connect to server.") from err
+    except Exception as err:
+        logging.exception("Unhandled exception while accessing server.")
+        raise ServerError("Cannot connect to server.") from err
 
 
 def handle_unzipping(save_path: Path, save_dir: Path,
                      pkg_name: str, padlen: int) -> None:
+    """Unpack a zipped folder, usually called right after downloading."""
     with ZipFile(save_path, "r") as zip_ref:
         namelist = zip_ref.namelist()
         tqdm_kwargs = _make_tqdm_kwargs(f"Extracting {pkg_name:<{padlen}}")
         for file in tqdm(iterable=namelist, total=len(namelist),
                          **tqdm_kwargs):
             zip_ref.extract(file, save_dir)
+
+
+def send_get(client, sub_url, stream: bool = False):
+    """Send a GET request (streamed or not) using an existing client.
+
+    The point of this function is mostly elaborate exception handling.
+    """
+    try:
+        if stream:
+            # client.stream returns a context manager; the status can only be
+            # checked once the caller has entered it, see handle_download
+            response = client.stream("GET", sub_url)
+        else:
+            response = client.get(sub_url)
+            response.raise_for_status()
+    except httpx.RequestError as err:
+        logging.exception("An error occurred while requesting %s.",
+                          err.request.url)
+        raise ServerError("Cannot connect to server.") from err
+    except httpx.HTTPStatusError as err:
+        logging.error("Error response %s while requesting %s.",
+                      err.response.status_code, err.request.url)
+        raise ServerError("Cannot connect to server.") from err
+    except Exception as err:
+        logging.exception("Unhandled exception while accessing server.")
+        raise ServerError("Cannot connect to server.") from err
+
+    return response
+
+
+def get_server_folder_contents(client, dir_name: str,
+                               unique_str: str = ".zip$") -> Iterator[str]:
+    """Find all zip files in a given server folder."""
+    dir_name = dir_name + "/" if not dir_name.endswith("/") else dir_name
+    response = send_get(client, dir_name)
+
+    soup = bs4.BeautifulSoup(response.content, features="lxml")
+    hrefs = soup.find_all("a", href=True, string=re.compile(unique_str))
+    pkgs = (href.string for href in hrefs)
+
+    return pkgs
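
# --- Editor's note (not part of the patch) -----------------------------------
# The send_get/handle_download pair above boils down to this httpx streaming
# pattern (URLs and file name below are placeholders):
import httpx

with httpx.Client(base_url="https://example.com") as client:
    # The response body is only read chunk by chunk inside the `with` block,
    # so even large package zips never sit fully in memory.
    with client.stream("GET", "pkgs/test_package.zip") as response:
        response.raise_for_status()
        with open("test_package.zip", "wb") as file:
            for chunk in response.iter_bytes(chunk_size=128):
                file.write(chunk)
# ------------------------------------------------------------------------------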
diff --git a/scopesim/server/example_data_utils.py b/scopesim/server/example_data_utils.py
index 86d1c33b..3bdf46f1 100644
--- a/scopesim/server/example_data_utils.py
+++ b/scopesim/server/example_data_utils.py
@@ -7,10 +7,7 @@
 from pathlib import Path
 from typing import List, Optional, Union, Iterable
 
-from urllib.error import HTTPError
-from urllib3.exceptions import HTTPError as HTTPError3
-
-import requests
+import httpx
 import bs4
 from astropy.utils.data import download_file
@@ -40,7 +37,7 @@ def get_server_elements(url: str, unique_str: str = "/") -> List[str]:
         unique_str = [unique_str]
 
     try:
-        result = requests.get(url).content
+        result = httpx.get(url).content
     except Exception as error:
         raise ValueError(f"URL returned error: {url}") from error
 
@@ -156,7 +153,7 @@
                                        cache=from_cache)
             save_path = save_dir / file_path.name
             file_path = shutil.copy2(cache_path, str(save_path))
-        except (HTTPError, HTTPError3) as error:
+        except httpx.HTTPError as error:
             msg = f"Unable to find file: {url + 'example_data/' + file_path}"
             raise ValueError(msg) from error
 
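
# --- Editor's note (not part of the patch) -----------------------------------
# The single `except httpx.HTTPError` above works because httpx.HTTPError is
# the common base class of both its network-level and status-level errors,
# covering what the separate urllib/urllib3 HTTPError imports did before:
import httpx

assert issubclass(httpx.RequestError, httpx.HTTPError)     # connect, timeout, ...
assert issubclass(httpx.HTTPStatusError, httpx.HTTPError)  # raise_for_status()

# One caveat that may be worth checking: download_example_data delegates to
# astropy's download_file, which uses urllib under the hood, so a failed
# download may still surface as a urllib error rather than as any httpx
# exception.
# ------------------------------------------------------------------------------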
- "Attempted URL was: %s."), api_url) - logging.error(error) - raise error - - # Make the base directories for this GitHub folder - (output_dir / download_dirs).mkdir(parents=True, exist_ok=True) - - for entry in data: - # if the entry is a further folder, walk through it - if entry["type"] == "dir": - download_github_folder(repo_url=entry["html_url"], - output_dir=output_dir) - - # if the entry is a file, download it - elif entry["type"] == "file": - try: + with create_client("", cached=False) as client: + data = send_get(client, api_url).json() + + # Make the base directories for this GitHub folder + (output_dir / download_dirs).mkdir(parents=True, exist_ok=True) + + for entry in data: + # if the entry is a further folder, walk through it + if entry["type"] == "dir": + download_github_folder(repo_url=entry["html_url"], + output_dir=output_dir) + + # if the entry is a file, download it + elif entry["type"] == "file": # download the file save_path = output_dir / entry["path"] - response = initiate_download(entry["download_url"]) - handle_download(response, save_path, entry["path"], - padlen=0, disable_bar=True) + handle_download(client, entry["download_url"], save_path, + entry["path"], padlen=0, disable_bar=True) logging.info("Downloaded: %s", entry["path"]) - - except (requests.exceptions.ConnectionError, - requests.exceptions.RetryError) as error: - logging.error(error) - raise ServerError("Cannot connect to server. " - f"Attempted URL was: {api_url}.") from error - except Exception as error: - logging.error(("Unhandled exception occured while accessing " - "server. Attempted URL was: %s."), api_url) - logging.error(error) - raise error diff --git a/scopesim/tests/tests_server/test_database.py b/scopesim/tests/tests_server/test_database.py index 3d05e5bc..b5fe93e6 100644 --- a/scopesim/tests/tests_server/test_database.py +++ b/scopesim/tests/tests_server/test_database.py @@ -11,6 +11,13 @@ from scopesim import rc +@pytest.fixture(scope="class") +def mock_client(): + # TODO: investigate proper mocking via httpx + with db.create_client(db.get_base_url()) as client: + yield client + + @pytest.mark.webtest def test_package_list_loads(): with pytest.warns(DeprecationWarning): @@ -53,32 +60,30 @@ def test_throws_for_nonexisting_release(self): class TestGetServerFolderContents: @pytest.mark.webtest - def test_downloads_locations(self): - pkgs = list(db.get_server_folder_contents("locations")) + def test_downloads_locations(self, mock_client): + pkgs = list(db.get_server_folder_contents(mock_client, "locations")) assert len(pkgs) > 0 @pytest.mark.webtest - def test_downloads_telescopes(self): - pkgs = list(db.get_server_folder_contents("telescopes")) + def test_downloads_telescopes(self, mock_client): + pkgs = list(db.get_server_folder_contents(mock_client, "telescopes")) assert len(pkgs) > 0 @pytest.mark.webtest - def test_downloads_instruments(self): - pkgs = list(db.get_server_folder_contents("instruments")) + def test_downloads_instruments(self, mock_client): + pkgs = list(db.get_server_folder_contents(mock_client, "instruments")) assert len(pkgs) > 0 @pytest.mark.webtest - def test_finds_armazones(self): - pkgs = list(db.get_server_folder_contents("locations")) + def test_finds_armazones(self, mock_client): + pkgs = list(db.get_server_folder_contents(mock_client, "locations")) assert "Armazones" in pkgs[0] @pytest.mark.webtest def test_throws_for_wrong_url_server(self): - original_url = rc.__config__["!SIM.file.server_base_url"] - rc.__config__["!SIM.file.server_base_url"] = 
"https://scopesim.univie.ac.at/bogus/" - with pytest.raises(db.ServerError): - list(db.get_server_folder_contents("locations")) - rc.__config__["!SIM.file.server_base_url"] = original_url + with db.create_client("https://scopesim.univie.ac.at/bogus/") as client: + with pytest.raises(db.ServerError): + list(db.get_server_folder_contents(client, "locations")) class TestGetServerElements: @@ -146,7 +151,7 @@ class TestDownloadPackages: def test_downloads_stable_package(self): with TemporaryDirectory() as tmpdir: db.download_packages(["test_package"], release="stable", - save_dir=tmpdir, from_cache=False) + save_dir=tmpdir) assert Path(tmpdir, "test_package.zip").exists() version_path = Path(tmpdir, "test_package", "version.yaml") @@ -160,7 +165,7 @@ def test_downloads_stable_package(self): def test_downloads_latest_package(self): with TemporaryDirectory() as tmpdir: db.download_packages("test_package", release="latest", - save_dir=tmpdir, from_cache=False) + save_dir=tmpdir) version_path = Path(tmpdir, "test_package", "version.yaml") with version_path.open("r", encoding="utf-8") as file: version_dict = yaml.full_load(file) @@ -172,37 +177,37 @@ def test_downloads_specific_package(self): release = "2022-04-09.dev" with TemporaryDirectory() as tmpdir: db.download_packages(["test_package"], release=release, - save_dir=tmpdir, from_cache=False) + save_dir=tmpdir) version_path = Path(tmpdir, "test_package", "version.yaml") with version_path.open("r", encoding="utf-8") as file: version_dict = yaml.full_load(file) assert version_dict["version"] == release - @pytest.mark.skip(reason="fails too often with timeout") + # @pytest.mark.skip(reason="fails too often with timeout") @pytest.mark.webtest def test_downloads_github_version_of_package_with_semicolon(self): release = "github:728761fc76adb548696205139e4e9a4260401dfc" with TemporaryDirectory() as tmpdir: db.download_packages("ELT", release=release, - save_dir=tmpdir, from_cache=False) + save_dir=tmpdir) filename = Path(tmpdir, "ELT", "EC_sky_25.tbl") assert filename.exists() - @pytest.mark.skip(reason="fails too often with timeout") + # @pytest.mark.skip(reason="fails too often with timeout") @pytest.mark.webtest def test_downloads_github_version_of_package_with_at_symbol(self): release = "github@728761fc76adb548696205139e4e9a4260401dfc" with TemporaryDirectory() as tmpdir: db.download_packages("ELT", release=release, - save_dir=tmpdir, from_cache=False) + save_dir=tmpdir) filename = Path(tmpdir, "ELT", "EC_sky_25.tbl") assert filename.exists() -@pytest.mark.skip(reason="fails too often with timeout") +# @pytest.mark.skip(reason="fails too often with timeout") class TestDownloadGithubFolder: @pytest.mark.webtest def test_downloads_current_package(self): @@ -227,7 +232,7 @@ def test_downloads_with_old_commit_hash(self): def test_throws_for_bad_url(self): with TemporaryDirectory() as tmpdir: url = "https://github.com/AstarVienna/irdb/tree/bogus/MICADO" - with pytest.raises(dbgh.ServerError): + with pytest.raises(db.ServerError): dbgh.download_github_folder(url, output_dir=tmpdir) diff --git a/scopesim/utils.py b/scopesim/utils.py index c7143ec2..5b414d4f 100644 --- a/scopesim/utils.py +++ b/scopesim/utils.py @@ -12,7 +12,7 @@ import functools from docutils.core import publish_string -import requests +import httpx import yaml import numpy as np from matplotlib import pyplot as plt @@ -984,14 +984,14 @@ def return_latest_github_actions_jobs_status( actions_yaml_name="tests.yml", ): """Get the status of the latest test run.""" - response = requests.get( 
diff --git a/scopesim/utils.py b/scopesim/utils.py
index c7143ec2..5b414d4f 100644
--- a/scopesim/utils.py
+++ b/scopesim/utils.py
@@ -12,7 +12,7 @@
 import functools
 
 from docutils.core import publish_string
-import requests
+import httpx
 import yaml
 import numpy as np
 from matplotlib import pyplot as plt
@@ -984,14 +984,14 @@ def return_latest_github_actions_jobs_status(
     actions_yaml_name="tests.yml",
 ):
     """Get the status of the latest test run."""
-    response = requests.get(
+    response = httpx.get(
         f"https://api.github.com/repos/{owner_name}/{repo_name}/actions/"
         f"workflows/{actions_yaml_name}/runs?branch={branch}&per_page=1"
     )
     dic = response.json()
     run_id = dic["workflow_runs"][0]["id"]
 
-    response = requests.get(
+    response = httpx.get(
        f"https://api.github.com/repos/{owner_name}/{repo_name}/actions/runs/"
         f"{run_id}/jobs"
     )
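
# --- Editor's note (not part of the patch) -----------------------------------
# One behavioural difference of the migration to keep in mind: the removed
# urllib3 Retry(status_forcelist=HTTP_RETRY_CODES) also retried on HTTP status
# codes such as 429 and 503, while httpx.HTTPTransport(retries=5) only retries
# failed connection attempts. If status-based retries turn out to be needed
# again, a small wrapper along these lines (hypothetical helper) would do:
import time
import httpx

def get_with_status_retries(client: httpx.Client, sub_url: str, total: int = 5,
                            backoff_factor: float = 2.0) -> httpx.Response:
    """Retry GETs on transient HTTP status codes with exponential backoff."""
    for attempt in range(total):
        response = client.get(sub_url)
        if response.status_code not in (429, 500, 502, 503):
            return response
        time.sleep(backoff_factor * 2 ** attempt)
    return response
# ------------------------------------------------------------------------------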