Skip to content

Commit

Permalink
Refactor download system and move from util to new download module (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
aazuspan committed Dec 6, 2021
1 parent 1bfb8c7 commit da8b2c6
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 68 deletions.
112 changes: 112 additions & 0 deletions wxee/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import os
import tempfile
import warnings
from typing import Optional

import requests
from requests.adapters import HTTPAdapter
from tqdm.auto import tqdm # type: ignore
from urllib3.util.retry import Retry # type: ignore

from wxee.exceptions import DownloadError


def _create_session(max_attempts: int, backoff: float) -> requests.Session:
    """Create a requests Session with automatic retrying.

    Parameters
    ----------
    max_attempts : int
        The maximum number of times to retry a failed connection.
    backoff : float
        A backoff factor applied between successive failed attempts. Larger
        numbers create increasingly long delays between requests.

    Returns
    -------
    requests.Session
        A session that transparently retries failed requests.

    References
    ----------
    https://www.peterbe.com/plog/best-practice-with-retries-with-requests
    """
    session = requests.Session()
    retry = Retry(total=max_attempts, backoff_factor=backoff)
    adapter = HTTPAdapter(max_retries=retry)
    # Mount the retrying adapter for both schemes so retries are applied
    # regardless of the URL scheme (the replaced _create_retry_session in
    # wxee.utils mounted both; mounting only https:// silently dropped
    # retry behavior for plain-http URLs).
    session.mount("http://", adapter)
    session.mount("https://", adapter)

    return session


def _download_url(
    url: str,
    out_dir: str,
    session: Optional[requests.Session] = None,
    timeout: int = 60,
    max_attempts: int = 10,
    backoff: float = 0.1,
    progress: bool = False,
) -> str:
    """Download a file from a URL to a tempfile in a specified directory.

    Parameters
    ----------
    url : str
        The URL address of the element to download.
    out_dir : str
        The directory path to save the temporary file to.
    session : requests.Session, optional
        An optional Session object to use for downloading. If none is given, a session
        will be created.
    timeout : int
        The maximum number of seconds to wait for responses before aborting the connection.
    max_attempts : int
        The maximum number of times to retry a connection.
    backoff : float
        A backoff factor to apply on successive failed attempts. Larger numbers will create
        increasingly long delays between requests.
    progress : bool
        If true, a progress bar will be displayed to track download progress.

    Returns
    -------
    str
        The path to the downloaded temp file.

    Raises
    ------
    DownloadError
        If an HTTP error, timeout, or connection error occurs while downloading.
    """
    # Explicit None check: a caller-supplied session must never be replaced
    # just because it happens to be falsy.
    if session is None:
        session = _create_session(max_attempts, backoff)

    # delete=False so the file survives the NamedTemporaryFile handle; the
    # caller is responsible for the file once this function returns.
    filename = tempfile.NamedTemporaryFile(mode="w+b", dir=out_dir, delete=False).name
    try:
        try:
            r = session.get(url, stream=True, timeout=timeout)
            # Servers may omit content-length; 0 disables the size check below.
            request_size = int(r.headers.get("content-length", 0))

            try:
                r.raise_for_status()
            except requests.exceptions.HTTPError as e:
                # Chain the original HTTPError so its status/URL details
                # remain visible in the traceback.
                raise DownloadError(
                    "An HTTP Error was encountered. Try increasing 'max_attempts' or running again later."
                ) from e

            with open(filename, "wb") as dst, tqdm(
                total=request_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
                desc="Downloading",
                disable=not progress,
            ) as bar:
                for chunk in r.iter_content(chunk_size=1024):
                    size = dst.write(chunk)
                    bar.update(size)

        except requests.exceptions.Timeout as e:
            raise DownloadError(
                "The connection timed out. Try increasing 'timeout' or running again later."
            ) from e
        except requests.exceptions.ConnectionError as e:
            raise DownloadError(
                "A ConnectionError was encountered. Try increasing 'max_attempts' or running again later."
            ) from e

    # If the download fails for any reason, delete the temp file. Bare raise
    # preserves the original traceback.
    except Exception:
        os.remove(filename)
        raise

    downloaded_size = os.path.getsize(filename)

    # Warn rather than raise: the data may still be usable, but the caller
    # should know it may be truncated.
    if downloaded_size != request_size:
        warnings.warn(
            f"Download error: {downloaded_size} bytes out of {request_size} were retrieved. Data may be incomplete or corrupted."
        )

    return filename
6 changes: 4 additions & 2 deletions wxee/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

from wxee import constants
from wxee.accessors import wx_accessor
from wxee.download import _download_url
from wxee.exceptions import DownloadError
from wxee.utils import (
_dataset_from_files,
_download_url,
_format_date,
_replace_if_null,
_set_nodata,
Expand Down Expand Up @@ -178,7 +178,9 @@ def _url_to_tif(
with tempfile.TemporaryDirectory(
dir=out_dir, prefix=constants.TMP_PREFIX
) as tmp:
zipped = _download_url(url, tmp, progress, max_attempts)
zipped = _download_url(
url, tmp, progress=progress, max_attempts=max_attempts
)
tifs = _unpack_file(zipped, out_dir)
self._process_tifs(tifs, file_per_band, masked, nodata)

Expand Down
66 changes: 0 additions & 66 deletions wxee/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,13 @@
import datetime
import itertools
import os
import tempfile
import warnings
from typing import Any, List, Tuple, Union
from zipfile import ZipFile

import ee # type: ignore
import joblib # type: ignore
import rasterio # type: ignore
import requests
import xarray as xr
from requests.adapters import HTTPAdapter
from tqdm.auto import tqdm # type: ignore
Expand Down Expand Up @@ -71,70 +69,6 @@ def _unpack_file(file: str, out_dir: str) -> List[str]:
return [os.path.join(out_dir, file) for file in unzipped]


def _download_url(url: str, out_dir: str, progress: bool, max_attempts: int) -> str:
    """Download a file from a URL to a specified directory.

    Parameters
    ----------
    url : str
        The URL address of the element to download.
    out_dir : str
        The directory path to save the temporary file to.
    progress : bool
        If true, a progress bar will be displayed to track download progress.
    max_attempts : int
        The maximum number of times to retry a connection.

    Returns
    -------
    str
        The path to the downloaded file.
    """
    # delete=False so the file survives the handle; the caller owns it on success.
    filename = tempfile.NamedTemporaryFile(mode="w+b", dir=out_dir, delete=False).name

    try:
        r = _create_retry_session(max_attempts).get(url, stream=True)
        r.raise_for_status()

        # Servers may omit content-length; 0 leaves the bar without a total.
        file_size = int(r.headers.get("content-length", 0))

        with open(filename, "w+b") as dst, tqdm(
            total=file_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
            desc="Downloading",
            disable=not progress,
        ) as bar:
            for data in r.iter_content(chunk_size=1024):
                size = dst.write(data)
                bar.update(size)
    # Delete the temp file on ANY failure — previously only a failed status
    # check cleaned up, so a request error or a connection dropped mid-stream
    # left a partial temp file behind.
    except Exception:
        os.remove(filename)
        raise

    return filename


def _create_retry_session(max_attempts: int) -> requests.Session:
    """Build a requests Session that transparently retries failed connections.

    https://www.peterbe.com/plog/best-practice-with-retries-with-requests
    """
    retry_policy = Retry(
        total=max_attempts,
        read=max_attempts,
        connect=max_attempts,
        backoff_factor=0.1,
    )
    retry_adapter = HTTPAdapter(max_retries=retry_policy)

    session = requests.Session()
    # Apply the retrying adapter to both URL schemes.
    for scheme in ("http://", "https://"):
        session.mount(scheme, retry_adapter)

    return session


def _dataset_from_files(files: List[str], masked: bool, nodata: int) -> xr.Dataset:
"""Create an xarray.Dataset from a list of raster files."""
das = [_dataarray_from_file(file, masked, nodata) for file in files]
Expand Down

0 comments on commit da8b2c6

Please sign in to comment.