-
Notifications
You must be signed in to change notification settings - Fork 2.2k
/
io.py
209 lines (157 loc) · 5.39 KB
/
io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import os
import socket
import ipaddress
import pkg_resources
import shutil
import glob
import io
import ujson as json
import itertools
import yaml
from urllib3.util import parse_url
from contextlib import contextmanager
from tempfile import mkstemp, mkdtemp
from appdirs import user_config_dir, user_data_dir, user_cache_dir
# full path import results in unit test failures
from .exceptions import InvalidUploadUrlError
# Application name handed to appdirs when resolving the per-user
# config, data and cache directories below.
_DIR_APP_NAME = 'label-studio'
def good_path(path):
    """Return ``path`` as an absolute path with a leading ``~`` expanded.

    :param path: Path to normalise; may be relative or start with ``~``.
    :return: The expanded, absolute form of ``path``.
    """
    expanded = os.path.expanduser(path)
    return os.path.abspath(expanded)
def find_node(package_name, node_path, node_type):
    """Locate a file or directory inside an installed package.

    The package's resource directory is walked top-down and the first match
    is returned.

    :param package_name: Name of the package whose directory tree is searched.
    :param node_path: Bare name (``'config.xml'``) or ``/``-separated relative
        path (``'static/js'``) of the node to find.
    :param node_type: One of ``'file'``, ``'dir'`` or ``'any'``.
    :return: Path of the first matching node.
    :raises IOError: If nothing in the package tree matches.
    """
    assert node_type in ('dir', 'file', 'any')
    basedir = pkg_resources.resource_filename(package_name, '')
    node_path = os.path.join(*node_path.split('/'))  # linux to windows compatibility
    search_by_path = '/' in node_path or '\\' in node_path
    # Anchor the suffix on a separator so that 'b/c.txt' matches only a
    # directory literally named 'b', not '.../ab/c.txt'.
    path_suffix = os.sep + node_path

    for path, dirs, filenames in os.walk(basedir):
        if node_type == 'file':
            nodes = filenames
        elif node_type == 'dir':
            nodes = dirs
        else:
            nodes = filenames + dirs

        if search_by_path:
            for found_node in nodes:
                found_node = os.path.join(path, found_node)
                if found_node.endswith(path_suffix):
                    return found_node
        elif node_path in nodes:
            return os.path.join(path, node_path)

    # The walk finished without returning: nothing matched.
    raise IOError(
        'Could not find "%s" at package "%s"' % (node_path, basedir)
    )
def find_file(file):
    """Return the path of the file ``file`` found inside the ``label_studio`` package.

    :param file: File name or ``/``-separated relative path to look for.
    """
    return find_node(package_name='label_studio', node_path=file, node_type='file')
def find_dir(directory):
    """Return the path of the directory ``directory`` found inside the ``label_studio`` package.

    :param directory: Directory name or ``/``-separated relative path to look for.
    """
    return find_node(package_name='label_studio', node_path=directory, node_type='dir')
@contextmanager
def get_temp_file():
    """Context manager yielding the path of a freshly created temporary file.

    The file is created with ``mkstemp``. Its descriptor stays open for the
    duration of the ``with`` block and is closed on exit, including when the
    block raises. The file itself is not deleted here; removing it is left
    to the caller.

    :return: Yields the path of the temporary file.
    """
    fd, path = mkstemp()
    try:
        yield path
    finally:
        # Without the finally, an exception in the caller's block would skip
        # the close and leak the descriptor.
        os.close(fd)
@contextmanager
def get_temp_dir():
    """Context manager yielding the path of a freshly created temporary directory.

    The directory is created with ``mkdtemp`` and removed, together with
    everything inside it, when the ``with`` block exits. Removal also happens
    when the block raises.

    :return: Yields the path of the temporary directory.
    """
    dirpath = mkdtemp()
    try:
        yield dirpath
    finally:
        # Without the finally, an exception in the caller's block would leave
        # the directory behind.
        shutil.rmtree(dirpath)
def get_config_dir():
    """Return the per-user configuration directory for the application.

    Creation of the directory is attempted but is best-effort: an ``OSError``
    from ``os.makedirs`` is ignored and the path is returned regardless.
    """
    path = user_config_dir(appname=_DIR_APP_NAME)
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        # Deliberately ignored; the caller still receives the intended path.
        return path
    return path
def get_data_dir():
    """Return the per-user data directory for the application, creating it if missing."""
    path = user_data_dir(appname=_DIR_APP_NAME)
    os.makedirs(path, exist_ok=True)
    return path
def get_cache_dir():
    """Return the per-user cache directory for the application, creating it if missing."""
    path = user_cache_dir(appname=_DIR_APP_NAME)
    os.makedirs(path, exist_ok=True)
    return path
def delete_dir_content(dirpath):
    """Remove the entries of ``dirpath`` matched by ``*``, keeping the directory itself.

    As with any ``*`` glob, names starting with a dot are not matched and are
    therefore left in place.

    :param dirpath: Directory whose contents are to be removed.
    """
    # Escape the directory part so glob metacharacters in its name
    # ('[', ']', '*', '?') are taken literally instead of as a pattern.
    pattern = glob.escape(dirpath) + '/*'
    for entry in glob.glob(pattern):
        remove_file_or_dir(entry)
def remove_file_or_dir(path):
    """Delete ``path``: a regular file is unlinked, a directory is removed recursively.

    Nothing happens when ``path`` is neither an existing file nor an
    existing directory.
    """
    if os.path.isfile(path):
        os.remove(path)
        return
    if os.path.isdir(path):
        shutil.rmtree(path)
def get_all_files_from_dir(d):
    """List the regular files located directly inside directory ``d``.

    :param d: Directory to inspect (not searched recursively).
    :return: List of ``d``-joined paths, in ``os.listdir`` order.
    """
    candidates = (os.path.join(d, entry) for entry in os.listdir(d))
    return [candidate for candidate in candidates if os.path.isfile(candidate)]
def iter_files(root_dir, ext):
    """Yield paths of files under ``root_dir`` whose name ends with ``ext``.

    The tree is walked recursively. The comparison is made against the
    lower-cased file name, so ``ext`` should be given in lower case.

    :param root_dir: Directory at which the walk starts.
    :param ext: Suffix passed to ``str.endswith``.
    """
    for current_dir, _subdirs, names in os.walk(root_dir):
        yield from (
            os.path.join(current_dir, name)
            for name in names
            if name.lower().endswith(ext)
        )
def json_load(file, int_keys=False):
    """Read a UTF-8 encoded JSON file and return its parsed content.

    :param file: Path of the JSON file.
    :param int_keys: When true, the top-level object's keys are converted
        with ``int`` and a new dict is returned.
    :return: The parsed data, or the int-keyed dict when ``int_keys`` is set.
    """
    with io.open(file, encoding='utf8') as stream:
        loaded = json.load(stream)
    if not int_keys:
        return loaded
    return {int(key): value for key, value in loaded.items()}
def read_yaml(filepath):
    """Parse a UTF-8 YAML file and return the loaded data.

    When ``filepath`` does not exist as given, it is looked up with
    ``find_file`` instead.

    NOTE(review): the document is loaded with ``yaml.FullLoader``; confirm
    that callers only pass trusted files.

    :param filepath: Path of the YAML file, or a name resolvable by ``find_file``.
    """
    resolved = filepath if os.path.exists(filepath) else find_file(filepath)
    with io.open(resolved, encoding='utf-8') as stream:
        return yaml.load(stream, Loader=yaml.FullLoader)  # nosec
def read_bytes_stream(filepath):
    """Read the whole file at ``filepath`` into an in-memory ``io.BytesIO``.

    :param filepath: Path of the file to read in binary mode.
    :return: A ``BytesIO`` holding the file's content, positioned at the start.
    """
    with open(filepath, mode='rb') as source:
        payload = source.read()
    return io.BytesIO(payload)
def get_all_dirs_from_dir(d):
    """List the sub-directories located directly inside directory ``d``.

    :param d: Directory to inspect (not searched recursively).
    :return: List of ``d``-joined paths, in ``os.listdir`` order.
    """
    candidates = (os.path.join(d, entry) for entry in os.listdir(d))
    return [candidate for candidate in candidates if os.path.isdir(candidate)]
class SerializableGenerator(list):
    """Generator that is serializable by JSON

    A ``list`` subclass wrapping a single-pass iterable. The first item is
    pulled eagerly. For a non-empty input, the remaining iterator is stored
    as the list's only element, so the list is non-empty exactly when the
    input was. Iteration yields the first item followed by the rest.
    """

    def __init__(self, iterable):
        remainder = iter(iterable)
        try:
            first = next(remainder)
        except StopIteration:
            # Empty input: nothing is appended, so the list stays empty.
            self._head = []
        else:
            self._head = iter([first])
            self.append(remainder)

    def __iter__(self):
        # self[:1] is either [] or [remaining iterator]; unpacking it chains
        # the remaining items after the eagerly fetched head.
        return itertools.chain(self._head, *self[:1])
def validate_upload_url(url, block_local_urls=True):
    """Utility function for defending against SSRF attacks. Raises
    - InvalidUploadUrlError if the url is not HTTP[S], has no host, or if
      block_local_urls is enabled and the URL resolves to a local address.
    - LabelStudioAPIException if the hostname cannot be resolved

    :param url: Url to be checked for validity/safety,
    :param block_local_urls: Whether urls that resolve to local/private networks should be rejected.
    """
    parsed_url = parse_url(url)

    if parsed_url.scheme not in ('http', 'https'):
        raise InvalidUploadUrlError

    domain = parsed_url.host
    if not domain:
        # No host to resolve (e.g. 'http:///path'); passing None on to
        # gethostbyname would raise TypeError instead of a validation error.
        raise InvalidUploadUrlError

    try:
        ip = socket.gethostbyname(domain)
    except (socket.error, UnicodeError):
        # UnicodeError is raised when the host name cannot be IDNA-encoded
        # (for instance an over-long label); treat it as unresolvable too.
        from core.utils.exceptions import LabelStudioAPIException
        raise LabelStudioAPIException(f"Can't resolve hostname {domain}")

    if not block_local_urls:
        return

    if ip == '0.0.0.0':  # nosec
        raise InvalidUploadUrlError

    local_subnets = [
        '127.0.0.0/8',
        '10.0.0.0/8',
        '172.16.0.0/12',
        '192.168.0.0/16',
        # Link-local range; includes the 169.254.169.254 cloud metadata endpoint.
        '169.254.0.0/16',
    ]
    # Parse the resolved address once rather than once per subnet.
    address = ipaddress.ip_address(ip)
    if any(address in ipaddress.ip_network(subnet) for subnet in local_subnets):
        raise InvalidUploadUrlError