-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
_aiohttp.py
518 lines (453 loc) · 20.9 KB
/
_aiohttp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
# --------------------------------------------------------------------------
#
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# --------------------------------------------------------------------------
from __future__ import annotations
import sys
from typing import Any, Optional, AsyncIterator as AsyncIteratorType, TYPE_CHECKING, overload, cast, Union, Type
from types import TracebackType
from collections.abc import AsyncIterator
import logging
import asyncio
import codecs
import aiohttp
import aiohttp.client_exceptions
from multidict import CIMultiDict
from azure.core.configuration import ConnectionConfiguration
from azure.core.exceptions import (
ServiceRequestError,
ServiceResponseError,
IncompleteReadError,
)
from azure.core.pipeline import AsyncPipeline
from ._base import HttpRequest
from ._base_async import AsyncHttpTransport, AsyncHttpResponse, _ResponseStopIteration
from ...utils._pipeline_transport_rest_shared import _aiohttp_body_helper
from .._tools import is_rest as _is_rest
from .._tools_async import (
handle_no_stream_rest_response as _handle_no_stream_rest_response,
)
if TYPE_CHECKING:
from ...rest import (
HttpRequest as RestHttpRequest,
AsyncHttpResponse as RestAsyncHttpResponse,
)
from ...rest._aiohttp import RestAioHttpTransportResponse
# Chunk size in bytes (10 KiB). Per the original note, the value mirrors the one
# used by the `requests` library.
# NOTE(review): not referenced in the visible part of this module - confirm usage.
CONTENT_CHUNK_SIZE = 10 * 1024
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
class AioHttpTransport(AsyncHttpTransport):
    """AioHttp HTTP sender implementation.

    Fully asynchronous implementation using the aiohttp library.

    :keyword session: The client session.
    :paramtype session: ~aiohttp.ClientSession
    :keyword bool session_owner: Session owner. Defaults True.
    :keyword bool use_env_settings: Uses proxy settings from environment. Defaults to True.

    .. admonition:: Example:

        .. literalinclude:: ../samples/test_example_async.py
            :start-after: [START aiohttp]
            :end-before: [END aiohttp]
            :language: python
            :dedent: 4
            :caption: Asynchronous transport with aiohttp.
    """

    def __init__(
        self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner: bool = True, **kwargs
    ):
        if loop and sys.version_info >= (3, 10):
            raise ValueError("Starting with Python 3.10, asyncio doesn’t support loop as a parameter anymore")
        self._loop = loop
        self._session_owner = session_owner
        self.session = session
        if not self._session_owner and not self.session:
            raise ValueError("session_owner cannot be False if no session is provided")
        self.connection_config = ConnectionConfiguration(**kwargs)
        self._use_env_settings = kwargs.pop("use_env_settings", True)
        # close() drops the session and the ownership flag; remembering that the
        # transport was opened once lets open() report a reuse-after-close clearly
        # instead of failing on a None session.
        self._has_been_opened = False

    async def __aenter__(self):
        await self.open()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        await self.close()

    async def open(self):
        """Opens the connection.

        Creates the aiohttp session when this transport owns it and none exists
        yet, then enters the session.

        :raises ValueError: If the transport was already closed, or if no session
         is available and this transport is not allowed to create one.
        """
        if self._has_been_opened and not self.session:
            raise ValueError(
                "HTTP transport has already been closed. "
                "You may check if you're calling a function outside of the `async with` of your client creation, "
                "or if you called `await close()` on your client already."
            )
        if not self.session:
            if not self._session_owner:
                raise ValueError("session_owner cannot be False and no session is available")
            jar = aiohttp.DummyCookieJar()
            clientsession_kwargs = {
                "trust_env": self._use_env_settings,
                "cookie_jar": jar,
                "auto_decompress": False,
            }
            if self._loop is not None:
                clientsession_kwargs["loop"] = self._loop
            self.session = aiohttp.ClientSession(**clientsession_kwargs)
        self._has_been_opened = True
        # pyright has trouble to understand that self.session is not None, since we raised at worst above
        self.session = cast(aiohttp.ClientSession, self.session)
        await self.session.__aenter__()

    async def close(self):
        """Closes the connection.

        Only a session owned by this transport is closed. After that the
        transport cannot be opened again.
        """
        if self._session_owner and self.session:
            await self.session.close()
            self._session_owner = False
            self.session = None

    def _build_ssl_config(self, cert, verify):
        """Build the SSL configuration.

        :param tuple cert: Cert information
        :param bool verify: SSL verification or path to CA file or directory
        :rtype: bool or str or :class:`ssl.SSLContext`
        :return: SSL Configuration
        """
        if cert or verify not in (True, False):
            import ssl

            if verify not in (True, False):
                # "verify" is neither True nor False here: it is handed over as the CA file.
                ssl_ctx = ssl.create_default_context(cafile=verify)
            else:
                # NOTE: reached whenever a client certificate is supplied with a boolean
                # "verify"; a default context is built whatever the boolean value is.
                ssl_ctx = ssl.create_default_context()
            if cert:
                ssl_ctx.load_cert_chain(*cert)
            return ssl_ctx
        return verify

    def _get_request_data(self, request):
        """Get the request data.

        :param request: The request object
        :type request: ~azure.core.pipeline.transport.HttpRequest or ~azure.core.rest.HttpRequest
        :rtype: bytes or ~aiohttp.FormData
        :return: The request data
        :raises ValueError: If a file entry has fewer than two items.
        """
        if request.files:
            form_data = aiohttp.FormData()
            for form_file, data in request.files.items():
                # Each entry is indexed as (filename, content[, content_type]).
                content_type = data[2] if len(data) > 2 else None
                try:
                    form_data.add_field(form_file, data[1], filename=data[0], content_type=content_type)
                except IndexError as err:
                    raise ValueError("Invalid formdata formatting: {}".format(data)) from err
            return form_data
        return request.data

    @overload
    async def send(self, request: HttpRequest, **config: Any) -> AsyncHttpResponse:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.pipeline.transport.HttpRequest
        :keyword any config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.pipeline.transport.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        :keyword str proxy: will define the proxy to use all the time
        """

    @overload
    async def send(self, request: RestHttpRequest, **config: Any) -> RestAsyncHttpResponse:
        """Send the `azure.core.rest` request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.rest.HttpRequest
        :keyword any config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.rest.AsyncHttpResponse

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        :keyword str proxy: will define the proxy to use all the time
        """

    async def send(
        self, request: Union[HttpRequest, RestHttpRequest], **config
    ) -> Union[AsyncHttpResponse, RestAsyncHttpResponse]:
        """Send the request using this HTTP sender.

        Will pre-load the body into memory to be available with a sync method.
        Pass stream=True to avoid this behavior.

        :param request: The HttpRequest object
        :type request: ~azure.core.rest.HttpRequest
        :keyword any config: Any keyword arguments
        :return: The AsyncHttpResponse
        :rtype: ~azure.core.rest.AsyncHttpResponse
        :raises ValueError: If the transport has already been closed.
        :raises ~azure.core.exceptions.ServiceResponseError: On a response error or a timeout.
        :raises ~azure.core.exceptions.ServiceRequestError: On any other aiohttp client error.

        :keyword bool stream: Defaults to False.
        :keyword dict proxies: dict of proxy to used based on protocol. Proxy is a dict (protocol, url)
        :keyword str proxy: will define the proxy to use all the time
        """
        await self.open()
        # auto_decompress is introduced in aiohttp 3.7. Falling back to False handles aiohttp 3.6-.
        auto_decompress = getattr(self.session, "auto_decompress", False)

        proxies = config.pop("proxies", None)
        if proxies and "proxy" not in config:
            # aiohttp needs a single proxy, so iterate until the protocol of the URL is found.
            # Reverse lexicographic order makes "https" be tried before "http",
            # so "http" is not used for an "https" URL.
            for protocol in sorted(proxies, reverse=True):
                if request.url.startswith(protocol):
                    config["proxy"] = proxies[protocol]
                    break

        response: Optional[Union[AsyncHttpResponse, RestAsyncHttpResponse]] = None
        config["ssl"] = self._build_ssl_config(
            cert=config.pop("connection_cert", self.connection_config.cert),
            verify=config.pop("connection_verify", self.connection_config.verify),
        )
        # If we know for sure there is not body, disable "auto content type"
        # Otherwise, aiohttp will send "application/octet-stream" even for empty POST request
        # and that break services like storage signature
        if not request.data and not request.files:
            config["skip_auto_headers"] = ["Content-Type"]
        try:
            stream_response = config.pop("stream", False)
            timeout = config.pop("connection_timeout", self.connection_config.timeout)
            read_timeout = config.pop("read_timeout", self.connection_config.read_timeout)
            socket_timeout = aiohttp.ClientTimeout(sock_connect=timeout, sock_read=read_timeout)
            result = await self.session.request(  # type: ignore
                request.method,
                request.url,
                headers=request.headers,
                data=self._get_request_data(request),
                timeout=socket_timeout,
                allow_redirects=False,
                **config,
            )
            if _is_rest(request):
                from azure.core.rest._aiohttp import RestAioHttpTransportResponse

                response = RestAioHttpTransportResponse(
                    request=request,
                    internal_response=result,
                    block_size=self.connection_config.data_block_size,
                    decompress=not auto_decompress,
                )
                if not stream_response:
                    await _handle_no_stream_rest_response(response)
            else:
                # Given the associated "if", this else is legacy implementation
                # but mypy do not know it, so using a cast
                request = cast(HttpRequest, request)
                response = AioHttpTransportResponse(
                    request,
                    result,
                    self.connection_config.data_block_size,
                    decompress=not auto_decompress,
                )
                if not stream_response:
                    await response.load_body()
        except aiohttp.client_exceptions.ClientResponseError as err:
            raise ServiceResponseError(err, error=err) from err
        except asyncio.TimeoutError as err:
            raise ServiceResponseError(err, error=err) from err
        except aiohttp.client_exceptions.ClientError as err:
            raise ServiceRequestError(err, error=err) from err
        return response
class AioHttpStreamDownloadGenerator(AsyncIterator):
    """Asynchronous iterator over the blocks of a response body.

    :param pipeline: The pipeline object
    :type pipeline: ~azure.core.pipeline.AsyncPipeline
    :param response: The client response object.
    :type response: ~azure.core.rest.AsyncHttpResponse
    :keyword bool decompress: If True which is default, will attempt to decode the body based
        on the *content-encoding* header.
    """

    @overload
    def __init__(
        self,
        pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse],
        response: AioHttpTransportResponse,
        *,
        decompress: bool = True,
    ) -> None:
        ...

    @overload
    def __init__(
        self,
        pipeline: AsyncPipeline[RestHttpRequest, RestAsyncHttpResponse],
        response: RestAioHttpTransportResponse,
        *,
        decompress: bool = True,
    ) -> None:
        ...

    def __init__(
        self,
        pipeline: AsyncPipeline,
        response: Union[AioHttpTransportResponse, RestAioHttpTransportResponse],
        *,
        decompress: bool = True,
    ) -> None:
        self.pipeline = pipeline
        self.request = response.request
        self.response = response
        self.block_size = response.block_size
        self._decompress = decompress
        raw_headers = response.internal_response.headers
        # A missing Content-Length header is reported as a length of 0.
        self.content_length = int(raw_headers.get("Content-Length", 0))
        # The zlib decompression object is created lazily, on the first compressed block.
        self._decompressor = None

    def __len__(self):
        return self.content_length

    def _inflate_if_needed(self, block, headers):
        """Return *block*, decompressed when decompression is enabled and applicable.

        :param bytes block: A non-empty block read from the response stream.
        :param headers: The headers of the underlying response.
        :rtype: bytes
        :return: The block, decompressed for "gzip" and "deflate" content encodings.
        """
        if not self._decompress:
            return block
        encoding = headers.get("Content-Encoding")
        if not encoding:
            return block
        encoding = encoding.lower()
        if encoding not in ("gzip", "deflate"):
            return block
        if self._decompressor is None:
            import zlib

            # 16 + MAX_WBITS expects a gzip container, -MAX_WBITS a raw deflate stream.
            wbits = -zlib.MAX_WBITS
            if encoding == "gzip":
                wbits = 16 + zlib.MAX_WBITS
            self._decompressor = zlib.decompressobj(wbits=wbits)
        return self._decompressor.decompress(block)

    async def __anext__(self):
        raw_response = self.response.internal_response
        try:
            block = await raw_response.content.read(self.block_size)
            if not block:
                # An empty read marks the end of the body.
                raise _ResponseStopIteration()
            return self._inflate_if_needed(block, raw_response.headers)
        except _ResponseStopIteration:
            raw_response.close()
            raise StopAsyncIteration()  # pylint: disable=raise-missing-from
        except aiohttp.client_exceptions.ClientPayloadError as err:
            # This is the case that server closes connection before we finish the reading. aiohttp library
            # raises ClientPayloadError.
            _LOGGER.warning("Incomplete download: %s", err)
            raw_response.close()
            raise IncompleteReadError(err, error=err) from err
        except (aiohttp.client_exceptions.ClientResponseError, asyncio.TimeoutError) as err:
            raise ServiceResponseError(err, error=err) from err
        except aiohttp.client_exceptions.ClientError as err:
            raise ServiceRequestError(err, error=err) from err
        except Exception as err:
            _LOGGER.warning("Unable to stream download: %s", err)
            raw_response.close()
            raise
class AioHttpTransportResponse(AsyncHttpResponse):
    """Response wrapper giving access to the body returned by aiohttp.

    :param request: The HttpRequest object
    :type request: ~azure.core.pipeline.transport.HttpRequest
    :param aiohttp_response: Returned from ClientSession.request().
    :type aiohttp_response: aiohttp.ClientResponse object
    :param block_size: block size of data sent over connection.
    :type block_size: int
    :keyword bool decompress: If True which is default, will attempt to decode the body based
        on the *content-encoding* header.
    """

    def __init__(
        self,
        request: HttpRequest,
        aiohttp_response: aiohttp.ClientResponse,
        block_size: Optional[int] = None,
        *,
        decompress: bool = True,
    ) -> None:
        super().__init__(request, aiohttp_response, block_size=block_size)
        # https://aiohttp.readthedocs.io/en/stable/client_reference.html#aiohttp.ClientResponse
        self.status_code = aiohttp_response.status
        self.headers = CIMultiDict(aiohttp_response.headers)
        self.reason = aiohttp_response.reason
        self.content_type = aiohttp_response.headers.get("content-type")
        # Raw body; stays None until load_body() has been awaited.
        self._content = None
        self._decompressed_content = False
        self._decompress = decompress

    def body(self) -> bytes:
        """Return the whole body as bytes in memory.

        :rtype: bytes
        :return: The whole response body.
        """
        return _aiohttp_body_helper(self)

    @staticmethod
    def _guess_encoding(payload, mimetype) -> Optional[str]:
        """Pick an encoding when neither the caller nor the charset parameter gave a usable one.

        :param bytes payload: The response body.
        :param mimetype: The parsed content type of the response.
        :rtype: str or None
        :return: "utf-8" for JSON-like content, else the detected encoding, if any.
        :raises RuntimeError: If the body is None.
        """
        if mimetype.type == "application" and mimetype.subtype in ["json", "rdap"]:
            # RFC 7159 states that the default encoding is UTF-8.
            # RFC 7483 defines application/rdap+json
            return "utf-8"
        if payload is None:
            raise RuntimeError("Cannot guess the encoding of a not yet read body")
        # Use the first available detection library, in order of preference.
        try:
            import cchardet as chardet
        except ImportError:  # pragma: no cover
            try:
                import chardet  # type: ignore
            except ImportError:  # pragma: no cover
                import charset_normalizer as chardet  # type: ignore[no-redef]
        # While "detect" can return a dict of float, in this context this won't happen
        # The cast is for pyright to be happy
        return cast(Optional[str], chardet.detect(payload)["encoding"])

    def text(self, encoding: Optional[str] = None) -> str:
        """Return the whole body as a string.

        If encoding is not provided, rely on aiohttp auto-detection.

        :param str encoding: The encoding to apply.
        :rtype: str
        :return: The whole response body as a string.
        """
        # super().text detects charset based on self._content() which is compressed
        # implement the decoding explicitly here
        payload = self.body()
        header_value = self.headers.get(aiohttp.hdrs.CONTENT_TYPE, "").lower()
        mimetype = aiohttp.helpers.parse_mimetype(header_value)

        # The charset of the content type is used only if the caller gives no encoding.
        charset = encoding or mimetype.parameters.get("charset")
        if charset:
            try:
                codecs.lookup(charset)
            except LookupError:
                # Unknown codec name: discard it and fall back to guessing.
                charset = None
        if not charset:
            charset = self._guess_encoding(payload, mimetype)
        if charset is None or charset == "utf-8":
            charset = "utf-8-sig"
        return payload.decode(charset)

    async def load_body(self) -> None:
        """Load in memory the body, so it could be accessible from sync methods."""
        try:
            self._content = await self.internal_response.read()
        except aiohttp.client_exceptions.ClientPayloadError as err:
            # This is the case that server closes connection before we finish the reading. aiohttp library
            # raises ClientPayloadError.
            raise IncompleteReadError(err, error=err) from err
        except (aiohttp.client_exceptions.ClientResponseError, asyncio.TimeoutError) as err:
            raise ServiceResponseError(err, error=err) from err
        except aiohttp.client_exceptions.ClientError as err:
            raise ServiceRequestError(err, error=err) from err

    def stream_download(
        self, pipeline: AsyncPipeline[HttpRequest, AsyncHttpResponse], **kwargs
    ) -> AsyncIteratorType[bytes]:
        """Generator for streaming response body data.

        :param pipeline: The pipeline object
        :type pipeline: azure.core.pipeline.AsyncPipeline
        :keyword bool decompress: If True which is default, will attempt to decode the body based
            on the *content-encoding* header.
        :rtype: AsyncIterator[bytes]
        :return: An iterator of bytes chunks.
        """
        return AioHttpStreamDownloadGenerator(pipeline, self, **kwargs)

    def __getstate__(self):
        # Be sure body is loaded in memory, otherwise not pickable and let it throw
        self.body()

        snapshot = dict(self.__dict__)
        # Remove the unpicklable entries.
        snapshot["internal_response"] = None  # aiohttp response are not pickable (see headers comments)
        snapshot["headers"] = CIMultiDict(self.headers)  # MultiDictProxy is not pickable
        return snapshot