Skip to content

Commit

Permalink
enhance downloading functions
Browse files Browse the repository at this point in the history
  • Loading branch information
wenh06 committed Apr 8, 2024
1 parent 158d5d0 commit 98665fa
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 40 deletions.
1 change: 1 addition & 0 deletions test/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ tensorboard
openpyxl
pre-commit
packaging
gdown
16 changes: 14 additions & 2 deletions test/test_utils/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,16 @@

import pytest

from torch_ecg.utils.download import http_get, url_is_reachable
from torch_ecg.utils.download import _download_from_google_drive, http_get, url_is_reachable

_TMP_DIR = Path(__file__).resolve().parents[2] / "tmp" / "test_download"
_TMP_DIR.mkdir(parents=True, exist_ok=True)


def test_http_get():
url = "https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=1"
# normally, direct downloading from dropbox with `dl=0` will not download the file
# http_get internally replaces `dl=0` with `dl=1` to force download
url = "https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=0"
http_get(url, _TMP_DIR / "action-test-zip-extract", extract=True, filename="test.zip")
shutil.rmtree(_TMP_DIR / "action-test-zip-extract")

Expand Down Expand Up @@ -45,6 +47,16 @@ def test_http_get():
http_get(url, _TMP_DIR, extract=True)
Path(_TMP_DIR / Path(url).name).unlink()

# test downloading from Google Drive
file_id = "1Yys567-MZIMf3eXGJd8bGrsWIvDatbsZ"
url = f"https://drive.google.com/file/d/{file_id}/view?usp=sharing"
with pytest.raises(AssertionError, match="filename can not be inferred from Google Drive URL"):
http_get(url, _TMP_DIR)
http_get(url, _TMP_DIR, filename="torch-ecg-paper.bib", extract=False)
(_TMP_DIR / "torch-ecg-paper.bib").unlink()
_download_from_google_drive(file_id, _TMP_DIR / "torch-ecg-paper.bib")
(_TMP_DIR / "torch-ecg-paper.bib").unlink()


def test_url_is_reachable():
assert url_is_reachable("https://www.dropbox.com/s/oz0n1j3o1m31cbh/action_test.zip?dl=1")
Expand Down
122 changes: 84 additions & 38 deletions torch_ecg/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,48 +69,60 @@ def http_get(
"""
if filename is not None:
assert not (Path(dst_dir) / filename).exists(), "file already exists"
print(f"Downloading {url}.")
if not is_compressed_file(url) and extract:
if filename is not None:
if not is_compressed_file(filename):
# if the URL is from Dropbox, replace the trailing `?dl=0` with `?dl=1`
if "dropbox.com" in url and url.endswith("?dl=0"):
url = f"{url[:-1]}1"
if "drive.google.com" in url:
assert filename is not None, "filename can not be inferred from Google Drive URL."
downloaded_file = tempfile.NamedTemporaryFile(
dir=dst_dir,
delete=False,
)
_download_from_google_drive(url, downloaded_file.name)
df_suffix = _suffix(filename)
else:
print(f"Downloading {url}.")
if not is_compressed_file(url) and extract:
if filename is not None:
if not is_compressed_file(filename):
warnings.warn(
"filename is given, and it is not a `zip` file or a compressed `tar` file. "
"Automatic decompression is turned off.",
RuntimeWarning,
)
extract = False
else:
pass
else:
warnings.warn(
"filename is given, and it is not a `zip` file or a compressed `tar` file. "
"Automatic decompression is turned off.",
"URL must be pointing to a `zip` file or a compressed `tar` file. "
"Automatic decompression is turned off. "
"The user is responsible for decompressing the file manually.",
RuntimeWarning,
)
extract = False
else:
pass
else:
warnings.warn(
"URL must be pointing to a `zip` file or a compressed `tar` file. "
"Automatic decompression is turned off. "
"The user is responsible for decompressing the file manually.",
RuntimeWarning,
)
extract = False
# for example "https://www.dropbox.com/s/xxx/test%3F.zip??dl=1"
# produces pure_url = "https://www.dropbox.com/s/xxx/test?.zip"
pure_url = urllib.parse.unquote(url.split("?")[0])
parent_dir = Path(dst_dir).parent
df_suffix = _suffix(pure_url) if filename is None else _suffix(filename)
downloaded_file = tempfile.NamedTemporaryFile(
dir=parent_dir,
suffix=df_suffix,
delete=False,
)
req = requests.get(url, stream=True, proxies=proxies)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
if req.status_code == 403 or req.status_code == 404:
raise Exception(f"Could not reach {url}.")
progress = tqdm(unit="B", unit_scale=True, total=total, dynamic_ncols=True, mininterval=1.0)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
downloaded_file.write(chunk)
progress.close()
downloaded_file.close()
# for example "https://www.dropbox.com/s/xxx/test%3F.zip??dl=1"
# produces pure_url = "https://www.dropbox.com/s/xxx/test?.zip"
pure_url = urllib.parse.unquote(url.split("?")[0])
parent_dir = Path(dst_dir).parent
df_suffix = _suffix(pure_url) if filename is None else _suffix(filename)
downloaded_file = tempfile.NamedTemporaryFile(
dir=parent_dir,
suffix=df_suffix,
delete=False,
)
req = requests.get(url, stream=True, proxies=proxies)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
if req.status_code == 403 or req.status_code == 404:
raise Exception(f"Could not reach {url}.")
progress = tqdm(unit="B", unit_scale=True, total=total, dynamic_ncols=True, mininterval=1.0)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
downloaded_file.write(chunk)
progress.close()
downloaded_file.close()
if extract:
if ".zip" in df_suffix:
_unzip_file(str(downloaded_file.name), str(dst_dir))
Expand Down Expand Up @@ -337,3 +349,37 @@ def url_is_reachable(url: str) -> bool:
return 100 <= r.status_code < 400
except Exception:
return False


def _download_from_google_drive(url_or_id: str, output: Union[str, bytes, os.PathLike], quiet: bool = False) -> None:
    """Download a file from Google Drive.

    The input is normalized to a direct-download URL of the form
    ``https://drive.google.com/u/0/uc?id=<file id>`` and then handed
    to ``gdown.download``.

    Parameters
    ----------
    url_or_id : str
        The URL of the file or the file ID.
        Accepted forms are a bare file ID, a sharing link
        ``https://drive.google.com/file/d/<file id>/view[?...]``
        (with or without the scheme, ``http`` or ``https``),
        or any other URL starting with ``https://drive.google.com``,
        which is passed through unchanged.
    output : `path-like`
        The output file path.
    quiet : bool, default False
        Whether to suppress the output.

    Raises
    ------
    ImportError
        If the optional dependency ``gdown`` is not installed.

    """
    try:
        # optional dependency, imported lazily so that the module
        # stays importable without it
        import gdown
    except ImportError as e:
        raise ImportError("gdown is required to download from Google Drive.") from e

    # tolerate links pasted without a scheme or with plain http,
    # so that they are not mistaken for a file ID below
    if url_or_id.startswith("drive.google.com"):
        url_or_id = f"https://{url_or_id}"
    elif url_or_id.startswith("http://drive.google.com"):
        url_or_id = "https://" + url_or_id[len("http://") :]

    if not url_or_id.startswith("https://drive.google.com"):
        # perhaps is the file ID
        url_or_id = f"https://drive.google.com/u/0/uc?id={url_or_id}"

    # remove the trailing "/view" segment of a sharing link,
    # together with its query string or fragment if there is one
    url_or_id = re.sub(r"/view(?:[?#].*)?$", "", url_or_id)

    # turn a sharing link into a direct-download link
    share_prefix = "https://drive.google.com/file/d/"
    if url_or_id.startswith(share_prefix):
        url_or_id = "https://drive.google.com/u/0/uc?id=" + url_or_id[len(share_prefix) :]

    if not quiet:
        print(f"Redirecting to {url_or_id}")
    # os.fsdecode handles str, bytes and os.PathLike alike,
    # whereas str() would turn a bytes path into "b'...'"
    gdown.download(url_or_id, os.fsdecode(output), quiet=quiet)

0 comments on commit 98665fa

Please sign in to comment.