#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""
This module contains integration with Azure Data Lake.
AzureDataLakeHook communicates via a REST API compatible with WebHDFS. Make sure that an
Airflow connection of type ``azure_data_lake`` exists. Authorization can be done by supplying a
login (=Client ID), a password (=Client Secret), and the extra fields ``tenant`` (Tenant) and
``account_name`` (Account Name); see the connection ``azure_data_lake_default`` for an example.
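
A minimal usage sketch (the paths below are illustrative)::

    hook = AzureDataLakeHook(azure_data_lake_conn_id="azure_data_lake_default")
    if hook.check_for_file("input/2021-01-01/events.json"):
        hook.download_file(
            local_path="/tmp/events.json",
            remote_path="input/2021-01-01/events.json",
        )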
"""
from typing import Any, Dict, Optional

from azure.datalake.store import core, lib, multithread

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook


class AzureDataLakeHook(BaseHook):
"""
    Interacts with Azure Data Lake.

    The client ID and client secret should be provided through the login and password
    fields of the connection. The tenant and account name should be set in the extra
    field as ``{"tenant": "<TENANT>", "account_name": "<ACCOUNT_NAME>"}``.

    :param azure_data_lake_conn_id: Reference to the
        :ref:`Azure Data Lake connection<howto/connection:adl>`.
"""

    conn_name_attr = 'azure_data_lake_conn_id'
    default_conn_name = 'azure_data_lake_default'
    conn_type = 'azure_data_lake'
    hook_name = 'Azure Data Lake'

    @staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
"""Returns connection widgets to add to connection form"""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField

        return {
"extra__azure_data_lake__tenant": StringField(
lazy_gettext('Azure Tenant ID'), widget=BS3TextFieldWidget()
),
"extra__azure_data_lake__account_name": StringField(
lazy_gettext('Azure DataLake Store Name'), widget=BS3TextFieldWidget()
),
}

    @staticmethod
def get_ui_field_behaviour() -> Dict[str, Any]:
"""Returns custom field behaviour"""
return {
"hidden_fields": ['schema', 'port', 'host', 'extra'],
"relabeling": {
'login': 'Azure Client ID',
'password': 'Azure Client Secret',
},
"placeholders": {
'login': 'client id',
'password': 'secret',
'extra__azure_data_lake__tenant': 'tenant id',
'extra__azure_data_lake__account_name': 'datalake store',
},
}

    def __init__(self, azure_data_lake_conn_id: str = default_conn_name) -> None:
        super().__init__()
        self.conn_id = azure_data_lake_conn_id
        self._conn: Optional[core.AzureDLFileSystem] = None
        self.account_name: Optional[str] = None

    def get_conn(self) -> core.AzureDLFileSystem:
"""Return a AzureDLFileSystem object."""
if not self._conn:
conn = self.get_connection(self.conn_id)
service_options = conn.extra_dejson
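            # Extras may be stored under plain keys or under the
            # 'extra__azure_data_lake__' prefix written by the connection form widgets.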
self.account_name = service_options.get('account_name') or service_options.get(
'extra__azure_data_lake__account_name'
)
tenant = service_options.get('tenant') or service_options.get('extra__azure_data_lake__tenant')
adl_creds = lib.auth(tenant_id=tenant, client_secret=conn.password, client_id=conn.login)
self._conn = core.AzureDLFileSystem(adl_creds, store_name=self.account_name)
self._conn.connect()
return self._conn

    def check_for_file(self, file_path: str) -> bool:
        """
        Check if a file exists on Azure Data Lake.

        :param file_path: Path and name of the file.
:return: True if the file exists, False otherwise.
:rtype: bool
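
        For example, with ``hook`` an instance of this class (the path is illustrative)::

            exists = hook.check_for_file("input/2021-01-01/events.json")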
"""
try:
files = self.get_conn().glob(file_path, details=False, invalidate_cache=True)
return len(files) == 1
except FileNotFoundError:
return False

    def upload_file(
self,
local_path: str,
remote_path: str,
nthreads: int = 64,
overwrite: bool = True,
buffersize: int = 4194304,
blocksize: int = 4194304,
**kwargs,
) -> None:
"""
        Upload a file to Azure Data Lake.

        :param local_path: Local path. Can be a single file, a directory (in which
            case it is uploaded recursively), or a glob pattern. Recursive glob
            patterns using ``**`` are not supported.
:param remote_path: Remote path to upload to; if multiple files, this is the
directory root to write within.
:param nthreads: Number of threads to use. If None, uses the number of cores.
        :param overwrite: Whether to forcibly overwrite existing files/directories.
            If False and the remote path is a directory, the operation will quit
            regardless of whether any files would be overwritten. If True, only
            matching file names are actually overwritten.
        :param buffersize: int [2**22]
            Number of bytes for the internal buffer. The buffer cannot be bigger
            than a chunk and cannot be smaller than a block.
        :param blocksize: int [2**22]
            Number of bytes for a block. Within each chunk, we write a smaller
            block for each API call. A block cannot be bigger than a chunk.
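
        For example (paths are illustrative)::

            hook.upload_file(
                local_path="/tmp/reports/*.csv",
                remote_path="reports/2021-01-01",
            )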
"""
multithread.ADLUploader(
self.get_conn(),
lpath=local_path,
rpath=remote_path,
nthreads=nthreads,
overwrite=overwrite,
buffersize=buffersize,
blocksize=blocksize,
**kwargs,
)

    def download_file(
self,
local_path: str,
remote_path: str,
nthreads: int = 64,
overwrite: bool = True,
buffersize: int = 4194304,
blocksize: int = 4194304,
**kwargs,
) -> None:
"""
        Download a file from Azure Data Lake.

:param local_path: local path. If downloading a single file, will write to this
specific file, unless it is an existing directory, in which case a file is
created within it. If downloading multiple files, this is the root
directory to write within. Will create directories as required.
:param remote_path: remote path/globstring to use to find remote files.
Recursive glob patterns using `**` are not supported.
:param nthreads: Number of threads to use. If None, uses the number of cores.
        :param overwrite: Whether to forcibly overwrite existing files/directories.
            If False and the remote path is a directory, the operation will quit
            regardless of whether any files would be overwritten. If True, only
            matching file names are actually overwritten.
        :param buffersize: int [2**22]
            Number of bytes for the internal buffer. The buffer cannot be bigger
            than a chunk and cannot be smaller than a block.
        :param blocksize: int [2**22]
            Number of bytes for a block. Within each chunk, we write a smaller
            block for each API call. A block cannot be bigger than a chunk.
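
        For example (paths are illustrative)::

            hook.download_file(
                local_path="/tmp/data",
                remote_path="input/2021-01-01/*.json",
                nthreads=8,
            )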
"""
multithread.ADLDownloader(
self.get_conn(),
lpath=local_path,
rpath=remote_path,
nthreads=nthreads,
overwrite=overwrite,
buffersize=buffersize,
blocksize=blocksize,
**kwargs,
)

    def list(self, path: str) -> list:
        """
        List files in Azure Data Lake Storage.

        :param path: Full path/globstring to use to list files in ADLS.
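
        For example (paths are illustrative)::

            matched = hook.list("reports/2021-*/*.csv")  # glob pattern, uses glob()
            everything = hook.list("reports")  # plain path, walks the tree recursively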
"""
if "*" in path:
return self.get_conn().glob(path)
else:
return self.get_conn().walk(path)

    def remove(self, path: str, recursive: bool = False, ignore_not_found: bool = True) -> None:
        """
        Remove files in Azure Data Lake Storage.

        :param path: A directory or file to remove in ADLS
        :param recursive: Whether to recurse into directories at the location and
            remove the files they contain
        :param ignore_not_found: Whether to silently ignore, rather than raise an
            error for, a file that is not found
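
        For example (the path is illustrative)::

            hook.remove("staging/2021-01-01", recursive=True)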
"""
try:
self.get_conn().remove(path=path, recursive=recursive)
except FileNotFoundError:
if ignore_not_found:
self.log.info("File %s not found", path)
else:
raise AirflowException(f"File {path} not found")