-
Notifications
You must be signed in to change notification settings - Fork 13
/
remote.py
100 lines (83 loc) · 2.48 KB
/
remote.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Load files from remote locations."""
# Core Library
import os
from urllib.request import urlcleanup, urlretrieve
# Third party
import mypy_boto3_s3 as s3
import requests
def load(source_url: str, sink_path: str, policy: str = "load_if_missing") -> None:
    """
    Fetch a remote file and store it at a local path.

    The prefix of ``source_url`` selects the download routine: ``http://``
    and ``https://`` use ``load_requests``, ``ftp://`` uses
    ``load_urlretrieve`` and ``s3://`` uses ``load_aws_s3``.

    Parameters
    ----------
    source_url : str
        Location of the remote file.
    sink_path : str
        Local path the file is written to.
    policy : {'load_always', 'load_if_missing'}
        With 'load_if_missing' nothing happens when ``sink_path`` is already
        a file. Any other value triggers the download unconditionally.

    Raises
    ------
    RuntimeError
        If ``source_url`` starts with none of the known protocol prefixes.
    """
    already_present = os.path.isfile(sink_path)
    if policy == "load_if_missing" and already_present:
        # Nothing to do: keep the local copy untouched.
        return

    # Prefix -> download routine; the prefixes are mutually exclusive.
    handlers_by_prefix = {
        "http://": load_requests,
        "https://": load_requests,
        "ftp://": load_urlretrieve,
        "s3://": load_aws_s3,
    }
    download = next(
        (
            routine
            for prefix, routine in handlers_by_prefix.items()
            if source_url.startswith(prefix)
        ),
        None,
    )
    if download is None:
        raise RuntimeError(f"Unknown protocol: source_url='{source_url}'")
    download(source_url, sink_path)
def load_requests(source_url: str, sink_path: str) -> None:
    """
    Load a file from a URL (e.g. http) using ``requests``.

    The response body is streamed to ``sink_path`` chunk by chunk instead of
    being held in memory as a whole.

    Parameters
    ----------
    source_url : str
        Where to load the file from.
    sink_path : str
        Where the loaded file is stored.

    Notes
    -----
    A file is only written when the response status code is 200. Any other
    status is ignored silently: nothing is written and no exception is
    raised, so callers that need to detect a failed download have to check
    for ``sink_path`` themselves.
    """
    # Iterating the response object directly yields 128-byte chunks; asking
    # for larger chunks means far fewer write calls for the same content.
    chunk_size = 64 * 1024

    # The context manager closes the response, which releases the streamed
    # connection even if the status is not 200 or writing the file fails.
    with requests.get(source_url, stream=True) as response:
        if response.status_code != 200:
            return
        with open(sink_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)
def load_urlretrieve(source_url: str, sink_path: str) -> None:
    """
    Download a URL to a local file via ``urllib.request.urlretrieve``.

    Parameters
    ----------
    source_url : str
        Location of the file to download.
    sink_path : str
        Local path the downloaded file is written to.
    """
    # Drop temporary files left behind by earlier urlretrieve calls.
    urlcleanup()
    urlretrieve(url=source_url, filename=sink_path)
def load_aws_s3(source_url: str, sink_path: str) -> None:
    """
    Load a file from AWS S3.

    Parameters
    ----------
    source_url : str
        Where to load the file from, in the form ``s3://<bucket>/<key>``.
    sink_path : str
        Where the loaded file is stored.

    Raises
    ------
    ValueError
        If ``source_url`` contains no object key, e.g. ``s3://bucket`` or
        ``s3://bucket/``.
    """
    # Parse parts. partition() always returns three items, so a URL without
    # any "/" after the bucket reaches the explicit empty-key check instead
    # of failing with an unrelated tuple-unpacking error.
    url = source_url[len("s3://") :]
    bucket, _, key = url.partition("/")
    if not key:
        raise ValueError(f"Key was empty for source_url='{source_url}'")

    # Import here to make this dependency optional. It happens after the
    # validation above so malformed URLs are rejected without needing boto3.
    # Third party
    import boto3

    # Download file
    client: s3.S3Client = boto3.client("s3")
    response = client.get_object(Bucket=bucket, Key=key)

    # Read the whole object first and only then create the local file, so a
    # failed read does not leave a partial file behind.
    body_string = response["Body"].read()
    with open(sink_path, "wb") as f:
        f.write(body_string)