diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index fda905a6e8c..4bfb855e760 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -4,8 +4,8 @@ contact_links: - name: 🤷💻🤦 StackOverflow url: https://stackoverflow.com/questions/tagged/aiohttp about: Please ask typical Q&A here -- name: 💬 Discourse - url: https://aio-libs.discourse.group/ +- name: 💬 Github Discussions + url: https://github.com/aio-libs/aiohttp/discussions about: Please start usage discussions here - name: 💬 Gitter Chat url: https://gitter.im/aio-libs/Lobby diff --git a/.github/workflows/ci.yml b/.github/workflows/ci-cd.yml similarity index 87% rename from .github/workflows/ci.yml rename to .github/workflows/ci-cd.yml index 947d6bab998..e6aa3e3c3d0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci-cd.yml @@ -69,9 +69,9 @@ jobs: # can be scanned by slotscheck. pip install -r requirements/base.txt -c requirements/constraints.txt slotscheck -v -m aiohttp - - name: Install libenchant-dev + - name: Install libenchant run: | - sudo apt install libenchant-dev + sudo apt install libenchant-2-dev - name: Install spell checker run: | pip install -r requirements/doc-spelling.txt -c requirements/constraints.txt @@ -133,15 +133,13 @@ jobs: needs: gen_llhttp strategy: matrix: - pyver: [3.7, 3.8, 3.9, '3.10'] + pyver: [3.8, 3.9, '3.10'] no-extensions: ['', 'Y'] os: [ubuntu, macos, windows] experimental: [false] exclude: - os: macos no-extensions: 'Y' - - os: macos - pyver: 3.7 - os: macos pyver: 3.8 - os: windows @@ -152,15 +150,27 @@ jobs: os: ubuntu experimental: false - os: macos - pyver: "3.11.0-alpha - 3.11.0" + pyver: "3.11" experimental: true no-extensions: 'Y' - os: ubuntu - pyver: "3.11.0-alpha - 3.11.0" + pyver: "3.11" experimental: false no-extensions: 'Y' - os: windows - pyver: "3.11.0-alpha - 3.11.0" + pyver: "3.11" + experimental: true + no-extensions: 'Y' + - os: ubuntu + pyver: "3.12" + experimental: true + no-extensions: 'Y' + - os: macos + pyver: "3.12" + experimental: true + no-extensions: 'Y' + - os: windows + pyver: "3.12" experimental: true no-extensions: 'Y' fail-fast: true @@ -175,6 +185,7 @@ jobs: id: python-install uses: actions/setup-python@v4 with: + allow-prereleases: true python-version: ${{ matrix.pyver }} - name: Get pip cache dir id: pip-cache @@ -359,7 +370,7 @@ jobs: run: | make cythonize - name: Build wheels - uses: pypa/cibuildwheel@v2.10.1 + uses: pypa/cibuildwheel@v2.14.1 env: CIBW_ARCHS_MACOS: x86_64 arm64 universal2 - uses: actions/upload-artifact@v3 @@ -368,13 +379,18 @@ jobs: path: ./wheelhouse/*.whl deploy: - permissions: - contents: write # to make release - name: Deploy - environment: release needs: [build-tarball, build-wheels] runs-on: ubuntu-latest + + permissions: + contents: write # IMPORTANT: mandatory for making GitHub Releases + id-token: write # IMPORTANT: mandatory for trusted publishing & sigstore + + environment: + name: pypi + url: https://pypi.org/p/aiohttp + steps: - name: Checkout uses: actions/checkout@v3 @@ -401,7 +417,27 @@ jobs: name: aiohttp version_file: aiohttp/__init__.py github_token: ${{ secrets.GITHUB_TOKEN }} - pypi_token: ${{ secrets.PYPI_API_TOKEN }} dist_dir: dist fix_issue_regex: "`#(\\d+) `_" fix_issue_repl: "(#\\1)" + + - name: >- + Publish 🐍📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + - name: Sign the dists with Sigstore + uses: sigstore/gh-action-sigstore-python@v1.2.3 + with: + inputs: >- + ./dist/*.tar.gz + ./dist/*.whl + + - name: Upload artifact signatures to GitHub Release + # Confusingly, this action also supports updating releases, not + # just creating them. This is what we want here, since we've manually + # created the release above. + uses: softprops/action-gh-release@v1 + with: + # dist/ contains the built packages, which smoketest-artifacts/ + # contains the signatures and certificates. + files: dist/** diff --git a/.gitmodules b/.gitmodules index 4a06d737c9c..6edb2eea5b2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,4 +1,4 @@ [submodule "vendor/llhttp"] path = vendor/llhttp url = https://github.com/nodejs/llhttp.git - branch = v6.0.6 + branch = v8.x diff --git a/.mypy.ini b/.mypy.ini index 9a888ebffff..b3e17b9731a 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -35,11 +35,5 @@ ignore_missing_imports = True [mypy-gunicorn.*] ignore_missing_imports = True -[mypy-tokio] -ignore_missing_imports = True - -[mypy-uvloop] -ignore_missing_imports = True - [mypy-python_on_whales] ignore_missing_imports = True diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1c8060da182..f41fbfe50ba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -26,15 +26,15 @@ repos: hooks: - id: check-merge-conflict - repo: https://github.com/asottile/yesqa - rev: v1.4.0 + rev: v1.5.0 hooks: - id: yesqa - repo: https://github.com/PyCQA/isort - rev: '5.10.1' + rev: '5.12.0' hooks: - id: isort - repo: https://github.com/psf/black - rev: '22.10.0' + rev: '23.7.0' hooks: - id: black language_version: python3 # Should be a command that runs python @@ -72,7 +72,7 @@ repos: - id: detect-private-key exclude: ^examples/ - repo: https://github.com/asottile/pyupgrade - rev: 'v3.3.0' + rev: 'v3.9.0' hooks: - id: pyupgrade args: ['--py36-plus'] @@ -82,6 +82,7 @@ repos: - id: flake8 additional_dependencies: - flake8-docstrings==1.6.0 + - flake8-requirements==1.7.8 exclude: "^docs/" - repo: https://github.com/Lucas-C/pre-commit-hooks-markup rev: v1.0.1 diff --git a/.readthedocs.yml b/.readthedocs.yml index 90fe80896bc..022dd5c3f53 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,14 +6,16 @@ version: 2 submodules: - include: all # [] + include: all exclude: [] recursive: true build: - image: latest + os: ubuntu-22.04 + tools: + python: "3.11" + python: - version: 3.8 install: - method: pip path: . diff --git a/CHANGES.rst b/CHANGES.rst index 52977431900..efa6052a7df 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -10,6 +10,21 @@ .. towncrier release notes start +3.8.4 (2023-02-12) +================== + +Bugfixes +-------- + +- Fixed incorrectly overwriting cookies with the same name and domain, but different path. + `#6638 `_ +- Fixed ``ConnectionResetError`` not being raised after client disconnection in SSL environments. + `#7180 `_ + + +---- + + 3.8.3 (2022-09-21) ================== diff --git a/CHANGES/.TEMPLATE.rst b/CHANGES/.TEMPLATE.rst index bc6016baf5c..a27a1994b53 100644 --- a/CHANGES/.TEMPLATE.rst +++ b/CHANGES/.TEMPLATE.rst @@ -12,8 +12,8 @@ {% if definitions[category]['showcontent'] %} {% for text, values in sections[section][category].items() %} -- {{ text }} - {{ values|join(',\n ') }} +- {{ text + '\n' }} + {{ values|join(',\n ') + '\n' }} {% endfor %} {% else %} diff --git a/CHANGES/2304.feature b/CHANGES/2304.feature new file mode 100644 index 00000000000..c89b812cba2 --- /dev/null +++ b/CHANGES/2304.feature @@ -0,0 +1 @@ +Support setting response header parameters max_line_size and max_field_size. diff --git a/CHANGES/3355.bugfix b/CHANGES/3355.bugfix new file mode 100644 index 00000000000..fd002cb00df --- /dev/null +++ b/CHANGES/3355.bugfix @@ -0,0 +1 @@ +Fixed a transport is :data:`None` error -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/3828.feature b/CHANGES/3828.feature deleted file mode 100644 index 9d78d813e95..00000000000 --- a/CHANGES/3828.feature +++ /dev/null @@ -1,4 +0,0 @@ -Disabled implicit switch-back to pure python mode. The build fails loudly if aiohttp -cannot be compiled with C Accelerators. Use `AIOHTTP_NO_EXTENSIONS=1` to explicitly -disable C Extensions complication and switch to Pure-Python mode. Note that Pure-Python -mode is significantly slower than compiled one. diff --git a/CHANGES/5494.bugfix b/CHANGES/5494.bugfix deleted file mode 100644 index 449b6bdf3d6..00000000000 --- a/CHANGES/5494.bugfix +++ /dev/null @@ -1,4 +0,0 @@ -Fixed the multipart POST requests processing to always release file -descriptors for the ``tempfile.Temporaryfile``-created -``_io.BufferedRandom`` instances of files sent within multipart request -bodies via HTTP POST requests. diff --git a/CHANGES/5494.misc b/CHANGES/5494.misc deleted file mode 100644 index 3187859776b..00000000000 --- a/CHANGES/5494.misc +++ /dev/null @@ -1,3 +0,0 @@ -Made sure to always close most of file descriptors and release other -resources in tests. Started ignoring ``ResourceWarning``s in pytest for -warnings that are hard to track. diff --git a/CHANGES/5854.bugfix b/CHANGES/5854.bugfix new file mode 100644 index 00000000000..b7de2f4d232 --- /dev/null +++ b/CHANGES/5854.bugfix @@ -0,0 +1 @@ +Fixed client timeout not working when incoming data is always available without waiting -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/5934.misc b/CHANGES/5934.misc new file mode 100644 index 00000000000..2040bf98a7b --- /dev/null +++ b/CHANGES/5934.misc @@ -0,0 +1 @@ +Add flake8-requirements to linting. diff --git a/CHANGES/6189.bugfix b/CHANGES/6189.bugfix deleted file mode 100644 index 0c22a6f7977..00000000000 --- a/CHANGES/6189.bugfix +++ /dev/null @@ -1 +0,0 @@ -Do not install "examples" as a top-level package. diff --git a/CHANGES/6195.bugfix b/CHANGES/6195.bugfix deleted file mode 100644 index dec44169e22..00000000000 --- a/CHANGES/6195.bugfix +++ /dev/null @@ -1 +0,0 @@ -Restored ability to connect IPv6-only host. diff --git a/CHANGES/6201.bugfix b/CHANGES/6201.bugfix deleted file mode 100644 index d09c2d5bbe8..00000000000 --- a/CHANGES/6201.bugfix +++ /dev/null @@ -1 +0,0 @@ -Remove ``Signal`` from ``__all__``, replace ``aiohttp.Signal`` with ``aiosignal.Signal`` in docs diff --git a/CHANGES/6205.misc b/CHANGES/6205.misc deleted file mode 100644 index 15b60ce0930..00000000000 --- a/CHANGES/6205.misc +++ /dev/null @@ -1,2 +0,0 @@ -Declared the minimum required version of ``setuptools`` v46.4.0 -in the :pep:`517` configuration file -- :user:`jameshilliard`. diff --git a/CHANGES/6240.doc b/CHANGES/6240.doc deleted file mode 100644 index 079548a93ca..00000000000 --- a/CHANGES/6240.doc +++ /dev/null @@ -1 +0,0 @@ -update quick starter demo codes. diff --git a/CHANGES/6274.doc.rst b/CHANGES/6274.doc.rst deleted file mode 100644 index 38d0ea1c992..00000000000 --- a/CHANGES/6274.doc.rst +++ /dev/null @@ -1 +0,0 @@ -Added an explanation of how tiny timeouts affect performance to the client reference document. diff --git a/CHANGES/6276.doc b/CHANGES/6276.doc deleted file mode 100644 index bfd06971499..00000000000 --- a/CHANGES/6276.doc +++ /dev/null @@ -1 +0,0 @@ -Add flake8-docstrings to flake8 configuration, enable subset of checks. diff --git a/CHANGES/6278.doc b/CHANGES/6278.doc deleted file mode 100644 index 2d18217379d..00000000000 --- a/CHANGES/6278.doc +++ /dev/null @@ -1 +0,0 @@ -Added information on running complex applications with additional tasks/processes -- :user:`Dreamsorcerer`. diff --git a/CHANGES/6305.bugfix b/CHANGES/6305.bugfix deleted file mode 100644 index 7d45266f500..00000000000 --- a/CHANGES/6305.bugfix +++ /dev/null @@ -1 +0,0 @@ -Made chunked encoding HTTP header check stricter. diff --git a/CHANGES/6594.feature b/CHANGES/6594.feature new file mode 100644 index 00000000000..4edadb07b3a --- /dev/null +++ b/CHANGES/6594.feature @@ -0,0 +1,2 @@ +Exported ``HTTPMove`` which can be used to catch any redirection request +that has a location -- :user:`dreamsorcerer`. diff --git a/CHANGES/6638.bugfix b/CHANGES/6638.bugfix deleted file mode 100644 index 8154dcfe3f3..00000000000 --- a/CHANGES/6638.bugfix +++ /dev/null @@ -1 +0,0 @@ -Do not overwrite cookies with same name and domain when the path is different. diff --git a/CHANGES/7056.feature b/CHANGES/7056.feature new file mode 100644 index 00000000000..fe22a3c5b10 --- /dev/null +++ b/CHANGES/7056.feature @@ -0,0 +1 @@ +Added `handler_cancellation `_ parameter to cancel web handler on client disconnection. -- by :user:`mosquito` diff --git a/CHANGES/7107.removal b/CHANGES/7107.removal new file mode 100644 index 00000000000..3a015f34b87 --- /dev/null +++ b/CHANGES/7107.removal @@ -0,0 +1 @@ +Removed deprecated ``.loop``, ``.setUpAsync()``, ``.tearDownAsync()`` and ``.get_app()`` from ``AioHTTPTestCase``. diff --git a/CHANGES/7131.feature b/CHANGES/7131.feature new file mode 100644 index 00000000000..bd77aff3613 --- /dev/null +++ b/CHANGES/7131.feature @@ -0,0 +1 @@ +Added support for using Basic Auth credentials from :file:`.netrc` file when making HTTP requests with the :py:class:`~aiohttp.ClientSession` ``trust_env`` argument is set to ``True`` -- by :user:`yuvipanda`. diff --git a/CHANGES/7149.bugfix b/CHANGES/7149.bugfix new file mode 100644 index 00000000000..dc3ac798d7c --- /dev/null +++ b/CHANGES/7149.bugfix @@ -0,0 +1 @@ +changed ``sock_read`` timeout to start after writing has finished, to avoid read timeouts caused by an unfinished write. -- by :user:`dtrifiro` diff --git a/CHANGES/7188.feature b/CHANGES/7188.feature new file mode 100644 index 00000000000..777144aa0e2 --- /dev/null +++ b/CHANGES/7188.feature @@ -0,0 +1 @@ +Added a graceful shutdown period which allows pending tasks to complete before the application's cleanup is called. The period can be adjusted with the ``shutdown_timeout`` parameter -- by :user:`Dreamsorcerer`. diff --git a/CHANGES/7191.misc b/CHANGES/7191.misc new file mode 100644 index 00000000000..55ac9ef9112 --- /dev/null +++ b/CHANGES/7191.misc @@ -0,0 +1 @@ +Made tests pass after the year 2039. diff --git a/CHANGES/7213.misc b/CHANGES/7213.misc new file mode 100644 index 00000000000..77fe5f4b1e5 --- /dev/null +++ b/CHANGES/7213.misc @@ -0,0 +1 @@ +Increase base_protocol.py test coverage. diff --git a/CHANGES/7237.bugfix b/CHANGES/7237.bugfix new file mode 100644 index 00000000000..26f85ea9c95 --- /dev/null +++ b/CHANGES/7237.bugfix @@ -0,0 +1 @@ +Fixed ``PermissionError`` when .netrc is unreadable due to permissions. diff --git a/CHANGES/7240.feature b/CHANGES/7240.feature new file mode 100644 index 00000000000..e656f1369f6 --- /dev/null +++ b/CHANGES/7240.feature @@ -0,0 +1 @@ +Turned access log into no-op when the logger is disabled. diff --git a/CHANGES/7259.bugfix b/CHANGES/7259.bugfix new file mode 100644 index 00000000000..0cc192e18b8 --- /dev/null +++ b/CHANGES/7259.bugfix @@ -0,0 +1 @@ +Fixed missing query in tracing method URLs when using ``yarl`` 1.9+. diff --git a/CHANGES/7281.removal b/CHANGES/7281.removal new file mode 100644 index 00000000000..ddbf457b175 --- /dev/null +++ b/CHANGES/7281.removal @@ -0,0 +1 @@ +Removed support for unsupported ``tokio`` event loop -- by :user:`Dreamsorcerer` diff --git a/CHANGES/7283.doc b/CHANGES/7283.doc new file mode 100644 index 00000000000..71a1a6722fc --- /dev/null +++ b/CHANGES/7283.doc @@ -0,0 +1 @@ +Added a note about possibly needing to update headers when using ``on_response_prepare`` -- by :user:`Dreamsorcerer` diff --git a/CHANGES/7302.bugfix b/CHANGES/7302.bugfix new file mode 100644 index 00000000000..e4b1b5f9201 --- /dev/null +++ b/CHANGES/7302.bugfix @@ -0,0 +1 @@ +Changed max 32-bit timestamp to an aware datetime object, for consistency with the non-32-bit one, and to avoid a DeprecationWarning on Python 3.12. diff --git a/CHANGES/7325.doc b/CHANGES/7325.doc new file mode 100644 index 00000000000..c0a08541d30 --- /dev/null +++ b/CHANGES/7325.doc @@ -0,0 +1 @@ +Complete trust_env parameter description to honor wss_proxy, ws_proxy or no_proxy env diff --git a/CHANGES/7334.doc b/CHANGES/7334.doc new file mode 100644 index 00000000000..95e9f8a1a9c --- /dev/null +++ b/CHANGES/7334.doc @@ -0,0 +1 @@ +Expanded SSL documentation with more examples (e.g. how to use certifi). -- by :user:`Dreamsorcerer` diff --git a/CHANGES/7335.misc b/CHANGES/7335.misc new file mode 100644 index 00000000000..9ccad2ed9d5 --- /dev/null +++ b/CHANGES/7335.misc @@ -0,0 +1 @@ +Fixed annotation of ``ssl`` parameter to disallow ``True``. -- by :user:`Dreamsorcerer` diff --git a/CHANGES/7336.removal b/CHANGES/7336.removal new file mode 100644 index 00000000000..86a92fc9ecc --- /dev/null +++ b/CHANGES/7336.removal @@ -0,0 +1 @@ +Dropped support for Python 3.7. -- by :user:`Dreamsorcerer` diff --git a/CHANGES/7346.feature b/CHANGES/7346.feature new file mode 100644 index 00000000000..9f91e6b7424 --- /dev/null +++ b/CHANGES/7346.feature @@ -0,0 +1,5 @@ +Upgrade the vendored copy of llhttp_ to v8.1.1 -- by :user:`webknjaz`. + +Thanks to :user:`sethmlarson` for pointing this out! + +.. _llhttp: https://llhttp.org diff --git a/CHANGES/7365.feature b/CHANGES/7365.feature new file mode 100644 index 00000000000..3429dd359c3 --- /dev/null +++ b/CHANGES/7365.feature @@ -0,0 +1 @@ +Added typing information to ``RawResponseMessage`` -- by :user:`Gobot1234` diff --git a/CHANGES/7366.feature b/CHANGES/7366.feature new file mode 100644 index 00000000000..8e38f70f898 --- /dev/null +++ b/CHANGES/7366.feature @@ -0,0 +1 @@ +Added information to C parser exceptions to show which character caused the error. -- by :user:`Dreamsorcerer` diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 9fc460021eb..c7e6b95145c 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -84,6 +84,7 @@ Dan Xu Daniel García Daniel Grossmann-Kavanagh Daniel Nelson +Daniele Trifirò Danny Song David Bibb David Michael Brown @@ -155,12 +156,14 @@ Ilya Gruzinov Ingmar Steen Ivan Lakovic Ivan Larin +J. Nick Koston Jacob Champion Jaesung Lee Jake Davis Jakob Ackermann Jakub Wilk Jan Buchar +Jan Gosmann Jarno Elonen Jashandeep Sohi Jean-Baptiste Estival @@ -200,6 +203,7 @@ Krzysztof Blazewicz Kyrylo Perevozchikov Kyungmin Lee Lars P. Søndergaard +Lee LieWhite Liu Hua Louis-Philippe Huberdeau Loïc Lajeanne @@ -353,6 +357,7 @@ Yury Pliner Yury Selivanov Yusuke Tsutsumi Yuval Ofir +Yuvi Panda Zainab Lawal Zeal Wierslee Zlatan Sičanica diff --git a/Makefile b/Makefile index c06fee549bf..f899a12c802 100644 --- a/Makefile +++ b/Makefile @@ -117,11 +117,7 @@ define run_tests_in_docker docker run --rm -ti -v `pwd`:/src -w /src "aiohttp-test-$(1)-$(2)" $(TEST_SPEC) endef -.PHONY: test-3.7-no-extensions test-3.7 test-3.8-no-extensions test-3.8 test-3.9-no-extensions test-3.9 test-3.10-no-extensions test-3.10 -test-3.7-no-extensions: - $(call run_tests_in_docker,3.7,y) -test-3.7: - $(call run_tests_in_docker,3.7,n) +.PHONY: test-3.8-no-extensions test-3.8 test-3.9-no-extensions test test-3.8-no-extensions: $(call run_tests_in_docker,3.8,y) test-3.8: diff --git a/README.rst b/README.rst index 6cd5eebe912..5436adb6834 100644 --- a/README.rst +++ b/README.rst @@ -29,13 +29,13 @@ Async http client/server framework :target: https://docs.aiohttp.org/ :alt: Latest Read The Docs -.. image:: https://img.shields.io/discourse/status?server=https%3A%2F%2Faio-libs.discourse.group - :target: https://aio-libs.discourse.group - :alt: Discourse status +.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs:matrix.org + :alt: Matrix Room — #aio-libs:matrix.org -.. image:: https://badges.gitter.im/Join%20Chat.svg - :target: https://gitter.im/aio-libs/Lobby - :alt: Chat on Gitter +.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs-space:matrix.org + :alt: Matrix Space — #aio-libs-space:matrix.org Key Features @@ -150,7 +150,7 @@ Feel free to make a Pull Request for adding your link to these pages! Communication channels ====================== -*aio-libs discourse group*: https://aio-libs.discourse.group +*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions *gitter chat* https://gitter.im/aio-libs/Lobby @@ -161,7 +161,6 @@ Please add *aiohttp* tag to your question there. Requirements ============ -- Python >= 3.7 - async-timeout_ - charset-normalizer_ - multidict_ diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index a14f2acf83c..6ffafdf1442 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -2,113 +2,106 @@ from typing import TYPE_CHECKING, Tuple -from . import hdrs as hdrs +from . import hdrs from .client import ( - BaseConnector as BaseConnector, - ClientConnectionError as ClientConnectionError, - ClientConnectorCertificateError as ClientConnectorCertificateError, - ClientConnectorError as ClientConnectorError, - ClientConnectorSSLError as ClientConnectorSSLError, - ClientError as ClientError, - ClientHttpProxyError as ClientHttpProxyError, - ClientOSError as ClientOSError, - ClientPayloadError as ClientPayloadError, - ClientProxyConnectionError as ClientProxyConnectionError, - ClientRequest as ClientRequest, - ClientResponse as ClientResponse, - ClientResponseError as ClientResponseError, - ClientSession as ClientSession, - ClientSSLError as ClientSSLError, - ClientTimeout as ClientTimeout, - ClientWebSocketResponse as ClientWebSocketResponse, - ContentTypeError as ContentTypeError, - Fingerprint as Fingerprint, - InvalidURL as InvalidURL, - NamedPipeConnector as NamedPipeConnector, - RequestInfo as RequestInfo, - ServerConnectionError as ServerConnectionError, - ServerDisconnectedError as ServerDisconnectedError, - ServerFingerprintMismatch as ServerFingerprintMismatch, - ServerTimeoutError as ServerTimeoutError, - TCPConnector as TCPConnector, - TooManyRedirects as TooManyRedirects, - UnixConnector as UnixConnector, - WSServerHandshakeError as WSServerHandshakeError, - request as request, + BaseConnector, + ClientConnectionError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientError, + ClientHttpProxyError, + ClientOSError, + ClientPayloadError, + ClientProxyConnectionError, + ClientRequest, + ClientResponse, + ClientResponseError, + ClientSession, + ClientSSLError, + ClientTimeout, + ClientWebSocketResponse, + ContentTypeError, + Fingerprint, + InvalidURL, + NamedPipeConnector, + RequestInfo, + ServerConnectionError, + ServerDisconnectedError, + ServerFingerprintMismatch, + ServerTimeoutError, + TCPConnector, + TooManyRedirects, + UnixConnector, + WSServerHandshakeError, + request, ) -from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar -from .formdata import FormData as FormData +from .cookiejar import CookieJar, DummyCookieJar +from .formdata import FormData from .helpers import BasicAuth, ChainMapProxy, ETag from .http import ( - HttpVersion as HttpVersion, - HttpVersion10 as HttpVersion10, - HttpVersion11 as HttpVersion11, - WebSocketError as WebSocketError, - WSCloseCode as WSCloseCode, - WSMessage as WSMessage, - WSMsgType as WSMsgType, + HttpVersion, + HttpVersion10, + HttpVersion11, + WebSocketError, + WSCloseCode, + WSMessage, + WSMsgType, ) from .multipart import ( - BadContentDispositionHeader as BadContentDispositionHeader, - BadContentDispositionParam as BadContentDispositionParam, - BodyPartReader as BodyPartReader, - MultipartReader as MultipartReader, - MultipartWriter as MultipartWriter, - content_disposition_filename as content_disposition_filename, - parse_content_disposition as parse_content_disposition, + BadContentDispositionHeader, + BadContentDispositionParam, + BodyPartReader, + MultipartReader, + MultipartWriter, + content_disposition_filename, + parse_content_disposition, ) from .payload import ( - PAYLOAD_REGISTRY as PAYLOAD_REGISTRY, - AsyncIterablePayload as AsyncIterablePayload, - BufferedReaderPayload as BufferedReaderPayload, - BytesIOPayload as BytesIOPayload, - BytesPayload as BytesPayload, - IOBasePayload as IOBasePayload, - JsonPayload as JsonPayload, - Payload as Payload, - StringIOPayload as StringIOPayload, - StringPayload as StringPayload, - TextIOPayload as TextIOPayload, - get_payload as get_payload, - payload_type as payload_type, -) -from .resolver import ( - AsyncResolver as AsyncResolver, - DefaultResolver as DefaultResolver, - ThreadedResolver as ThreadedResolver, + PAYLOAD_REGISTRY, + AsyncIterablePayload, + BufferedReaderPayload, + BytesIOPayload, + BytesPayload, + IOBasePayload, + JsonPayload, + Payload, + StringIOPayload, + StringPayload, + TextIOPayload, + get_payload, + payload_type, ) +from .resolver import AsyncResolver, DefaultResolver, ThreadedResolver from .streams import ( - EMPTY_PAYLOAD as EMPTY_PAYLOAD, - DataQueue as DataQueue, - EofStream as EofStream, - FlowControlDataQueue as FlowControlDataQueue, - StreamReader as StreamReader, + EMPTY_PAYLOAD, + DataQueue, + EofStream, + FlowControlDataQueue, + StreamReader, ) from .tracing import ( - TraceConfig as TraceConfig, - TraceConnectionCreateEndParams as TraceConnectionCreateEndParams, - TraceConnectionCreateStartParams as TraceConnectionCreateStartParams, - TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams, - TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams, - TraceConnectionReuseconnParams as TraceConnectionReuseconnParams, - TraceDnsCacheHitParams as TraceDnsCacheHitParams, - TraceDnsCacheMissParams as TraceDnsCacheMissParams, - TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams, - TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams, - TraceRequestChunkSentParams as TraceRequestChunkSentParams, - TraceRequestEndParams as TraceRequestEndParams, - TraceRequestExceptionParams as TraceRequestExceptionParams, - TraceRequestRedirectParams as TraceRequestRedirectParams, - TraceRequestStartParams as TraceRequestStartParams, - TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams, + TraceConfig, + TraceConnectionCreateEndParams, + TraceConnectionCreateStartParams, + TraceConnectionQueuedEndParams, + TraceConnectionQueuedStartParams, + TraceConnectionReuseconnParams, + TraceDnsCacheHitParams, + TraceDnsCacheMissParams, + TraceDnsResolveHostEndParams, + TraceDnsResolveHostStartParams, + TraceRequestChunkSentParams, + TraceRequestEndParams, + TraceRequestExceptionParams, + TraceRequestRedirectParams, + TraceRequestStartParams, + TraceResponseChunkReceivedParams, ) -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover # At runtime these are lazy-loaded at the bottom of the file. - from .worker import ( - GunicornUVLoopWebWorker as GunicornUVLoopWebWorker, - GunicornWebWorker as GunicornWebWorker, - ) + from .worker import GunicornUVLoopWebWorker, GunicornWebWorker __all__: Tuple[str, ...] = ( "hdrs", diff --git a/aiohttp/_http_parser.pyx b/aiohttp/_http_parser.pyx index bebd9894374..4f39dd0c978 100644 --- a/aiohttp/_http_parser.pyx +++ b/aiohttp/_http_parser.pyx @@ -546,7 +546,13 @@ cdef class HttpParser: ex = self._last_error self._last_error = None else: - ex = parser_error_from_errno(self._cparser) + after = cparser.llhttp_get_error_pos(self._cparser) + before = data[:after - self.py_buf.buf] + after_b = after.split(b"\n", 1)[0] + before = before.rsplit(b"\n", 1)[-1] + data = before + after_b + pointer = " " * (len(repr(before))-1) + "^" + ex = parser_error_from_errno(self._cparser, data, pointer) self._payload = None raise ex @@ -797,7 +803,7 @@ cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1: return 0 -cdef parser_error_from_errno(cparser.llhttp_t* parser): +cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer): cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser) cdef bytes desc = cparser.llhttp_get_error_reason(parser) @@ -829,4 +835,4 @@ cdef parser_error_from_errno(cparser.llhttp_t* parser): else: cls = BadHttpMessage - return cls(desc.decode('latin-1')) + return cls("{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer)) diff --git a/aiohttp/base_protocol.py b/aiohttp/base_protocol.py index 8189835e211..4c9f0a752e3 100644 --- a/aiohttp/base_protocol.py +++ b/aiohttp/base_protocol.py @@ -18,11 +18,15 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._loop: asyncio.AbstractEventLoop = loop self._paused = False self._drain_waiter: Optional[asyncio.Future[None]] = None - self._connection_lost = False self._reading_paused = False self.transport: Optional[asyncio.Transport] = None + @property + def connected(self) -> bool: + """Return True if the connection is open.""" + return self.transport is not None + def pause_writing(self) -> None: assert not self._paused self._paused = True @@ -59,7 +63,6 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = tr def connection_lost(self, exc: Optional[BaseException]) -> None: - self._connection_lost = True # Wake up the writer if currently paused. self.transport = None if not self._paused: @@ -76,7 +79,7 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: waiter.set_exception(exc) async def _drain_helper(self) -> None: - if self._connection_lost: + if not self.connected: raise ConnectionResetError("Connection lost") if not self._paused: return diff --git a/aiohttp/client.py b/aiohttp/client.py index c40745771cb..9050cc4120c 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -12,6 +12,7 @@ from contextlib import suppress from types import SimpleNamespace, TracebackType from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -21,6 +22,7 @@ Generic, Iterable, List, + Literal, Mapping, Optional, Set, @@ -37,44 +39,39 @@ from . import hdrs, http, payload from .abc import AbstractCookieJar from .client_exceptions import ( - ClientConnectionError as ClientConnectionError, - ClientConnectorCertificateError as ClientConnectorCertificateError, - ClientConnectorError as ClientConnectorError, - ClientConnectorSSLError as ClientConnectorSSLError, - ClientError as ClientError, - ClientHttpProxyError as ClientHttpProxyError, - ClientOSError as ClientOSError, - ClientPayloadError as ClientPayloadError, - ClientProxyConnectionError as ClientProxyConnectionError, - ClientResponseError as ClientResponseError, - ClientSSLError as ClientSSLError, - ContentTypeError as ContentTypeError, - InvalidURL as InvalidURL, - ServerConnectionError as ServerConnectionError, - ServerDisconnectedError as ServerDisconnectedError, - ServerFingerprintMismatch as ServerFingerprintMismatch, - ServerTimeoutError as ServerTimeoutError, - TooManyRedirects as TooManyRedirects, - WSServerHandshakeError as WSServerHandshakeError, + ClientConnectionError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientError, + ClientHttpProxyError, + ClientOSError, + ClientPayloadError, + ClientProxyConnectionError, + ClientResponseError, + ClientSSLError, + ContentTypeError, + InvalidURL, + ServerConnectionError, + ServerDisconnectedError, + ServerFingerprintMismatch, + ServerTimeoutError, + TooManyRedirects, + WSServerHandshakeError, ) from .client_reqrep import ( - SSL_ALLOWED_TYPES as SSL_ALLOWED_TYPES, - ClientRequest as ClientRequest, - ClientResponse as ClientResponse, - Fingerprint as Fingerprint, - RequestInfo as RequestInfo, + SSL_ALLOWED_TYPES, + ClientRequest, + ClientResponse, + Fingerprint, + RequestInfo, ) from .client_ws import ( DEFAULT_WS_CLIENT_TIMEOUT, - ClientWebSocketResponse as ClientWebSocketResponse, + ClientWebSocketResponse, ClientWSTimeout, ) -from .connector import ( - BaseConnector as BaseConnector, - NamedPipeConnector as NamedPipeConnector, - TCPConnector as TCPConnector, - UnixConnector as UnixConnector, -) +from .connector import BaseConnector, NamedPipeConnector, TCPConnector, UnixConnector from .cookiejar import CookieJar from .helpers import ( _SENTINEL, @@ -131,10 +128,10 @@ ) -try: +if TYPE_CHECKING: from ssl import SSLContext -except ImportError: # pragma: no cover - SSLContext = object # type: ignore[misc,assignment] +else: + SSLContext = None @dataclasses.dataclass(frozen=True) @@ -191,6 +188,8 @@ class ClientSession: "_ws_response_class", "_trace_configs", "_read_bufsize", + "_max_line_size", + "_max_field_size", ) def __init__( @@ -218,6 +217,8 @@ def __init__( requote_redirect_url: bool = True, trace_configs: Optional[List[TraceConfig]] = None, read_bufsize: int = 2**16, + max_line_size: int = 8190, + max_field_size: int = 8190, ) -> None: if base_url is None or isinstance(base_url, URL): self._base_url: Optional[URL] = base_url @@ -266,6 +267,8 @@ def __init__( self._trust_env = trust_env self._requote_redirect_url = requote_redirect_url self._read_bufsize = read_bufsize + self._max_line_size = max_line_size + self._max_field_size = max_field_size # Convert to list of tuples if headers: @@ -293,21 +296,16 @@ def __init_subclass__(cls: Type["ClientSession"]) -> None: ) def __del__(self, _warnings: Any = warnings) -> None: - try: - if not self.closed: - _warnings.warn( - f"Unclosed client session {self!r}", - ResourceWarning, - source=self, - ) - context = {"client_session": self, "message": "Unclosed client session"} - if self._source_traceback is not None: - context["source_traceback"] = self._source_traceback - self._loop.call_exception_handler(context) - except AttributeError: - # loop was not initialized yet, - # either self._connector or self._loop doesn't exist - pass + if not self.closed: + _warnings.warn( + f"Unclosed client session {self!r}", + ResourceWarning, + source=self, + ) + context = {"client_session": self, "message": "Unclosed client session"} + if self._source_traceback is not None: + context["source_traceback"] = self._source_traceback + self._loop.call_exception_handler(context) def request( self, method: str, url: StrOrURL, **kwargs: Any @@ -347,13 +345,14 @@ async def _request( proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, timeout: Union[ClientTimeout, _SENTINEL, None] = sentinel, - ssl: Optional[Union[SSLContext, bool, Fingerprint]] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, proxy_headers: Optional[LooseHeaders] = None, trace_request_ctx: Optional[SimpleNamespace] = None, read_bufsize: Optional[int] = None, auto_decompress: Optional[bool] = None, + max_line_size: Optional[int] = None, + max_field_size: Optional[int] = None, ) -> ClientResponse: - # NOTE: timeout clamps existing connect and read timeouts. We cannot # set the default to None because we need to detect if the user wants # to use the existing timeouts by setting timeout to None. @@ -377,6 +376,7 @@ async def _request( redirects = 0 history = [] version = self._version + params = params or {} # Merge with default headers and transform to CIMultiDict headers = self._prepare_headers(headers) @@ -415,6 +415,12 @@ async def _request( if auto_decompress is None: auto_decompress = self._auto_decompress + if max_line_size is None: + max_line_size = self._max_line_size + + if max_field_size is None: + max_field_size = self._max_field_size + traces = [ Trace( self, @@ -492,6 +498,7 @@ async def _request( ssl=ssl, proxy_headers=proxy_headers, traces=traces, + trust_env=self.trust_env, ) # connection timeout @@ -520,6 +527,8 @@ async def _request( read_timeout=real_timeout.sock_read, read_bufsize=read_bufsize, timeout_ceil_threshold=self._connector._timeout_ceil_threshold, + max_line_size=max_line_size, + max_field_size=max_field_size, ) try: @@ -544,7 +553,6 @@ async def _request( # redirects if resp.status in (301, 302, 303, 307, 308) and allow_redirects: - for trace in traces: await trace.send_request_redirect( method, url.update_query(params), headers, resp @@ -608,7 +616,7 @@ async def _request( headers.pop(hdrs.AUTHORIZATION, None) url = parsed_url - params = None + params = {} resp.release() continue @@ -670,7 +678,7 @@ def ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, None, Fingerprint] = None, + ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, @@ -716,7 +724,7 @@ async def _ws_connect( headers: Optional[LooseHeaders] = None, proxy: Optional[StrOrURL] = None, proxy_auth: Optional[BasicAuth] = None, - ssl: Union[SSLContext, bool, None, Fingerprint] = None, + ssl: Union[SSLContext, Literal[False], None, Fingerprint] = None, proxy_headers: Optional[LooseHeaders] = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, @@ -753,7 +761,7 @@ async def _ws_connect( default_headers = { hdrs.UPGRADE: "websocket", - hdrs.CONNECTION: "upgrade", + hdrs.CONNECTION: "Upgrade", hdrs.SEC_WEBSOCKET_VERSION: "13", } @@ -1086,7 +1094,6 @@ async def __aexit__( class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]): - __slots__ = ("_coro", "_resp") def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: @@ -1143,7 +1150,6 @@ async def __aexit__( class _SessionRequestContextManager: - __slots__ = ("_coro", "_resp", "_session") def __init__( @@ -1199,6 +1205,8 @@ def request( version: HttpVersion = http.HttpVersion11, connector: Optional[BaseConnector] = None, read_bufsize: Optional[int] = None, + max_line_size: int = 8190, + max_field_size: int = 8190, ) -> _SessionRequestContextManager: """Constructs and sends a request. @@ -1269,6 +1277,8 @@ def request( proxy=proxy, proxy_auth=proxy_auth, read_bufsize=read_bufsize, + max_line_size=max_line_size, + max_field_size=max_field_size, ), session, ) diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index 0e6c414ea7d..bfa9ea84a97 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -153,11 +153,12 @@ def set_response_params( read_timeout: Optional[float] = None, read_bufsize: int = 2**16, timeout_ceil_threshold: float = 5, + max_line_size: int = 8190, + max_field_size: int = 8190, ) -> None: self._skip_payload = skip_payload self._read_timeout = read_timeout - self._reschedule_timeout() self._timeout_ceil_threshold = timeout_ceil_threshold @@ -170,6 +171,8 @@ def set_response_params( response_with_body=not skip_payload, read_until_eof=read_until_eof, auto_decompress=auto_decompress, + max_line_size=max_line_size, + max_field_size=max_field_size, ) if self._tail: @@ -193,6 +196,9 @@ def _reschedule_timeout(self) -> None: else: self._read_timeout_handle = None + def start_timeout(self) -> None: + self._reschedule_timeout() + def _on_read_timeout(self) -> None: exc = ServerTimeoutError("Timeout on reading data from socket") self.set_exception(exc) diff --git a/aiohttp/client_reqrep.py b/aiohttp/client_reqrep.py index 9478e8d36c5..3d7a90d6d2d 100644 --- a/aiohttp/client_reqrep.py +++ b/aiohttp/client_reqrep.py @@ -1,5 +1,6 @@ import asyncio import codecs +import contextlib import dataclasses import functools import io @@ -16,6 +17,7 @@ Dict, Iterable, List, + Literal, Mapping, Optional, Tuple, @@ -37,6 +39,7 @@ InvalidURL, ServerFingerprintMismatch, ) +from .compression_utils import HAS_BROTLI from .formdata import FormData from .hdrs import CONTENT_TYPE from .helpers import ( @@ -44,14 +47,15 @@ BasicAuth, HeadersMixin, TimerNoop, + basicauth_from_netrc, is_expected_content_type, + netrc_from_env, noop, parse_mimetype, reify, set_result, ) from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11, StreamWriter -from .http_parser import HAS_BROTLI from .log import client_logger from .streams import StreamReader from .typedefs import ( @@ -153,7 +157,7 @@ class ConnectionKey: host: str port: Optional[int] is_ssl: bool - ssl: Union[SSLContext, None, bool, Fingerprint] + ssl: Union[SSLContext, None, Literal[False], Fingerprint] proxy: Optional[URL] proxy_auth: Optional[BasicAuth] proxy_headers_hash: Optional[int] # hash(CIMultiDict) @@ -207,9 +211,10 @@ def __init__( proxy_auth: Optional[BasicAuth] = None, timer: Optional[BaseTimerContext] = None, session: Optional["ClientSession"] = None, - ssl: Union[SSLContext, bool, Fingerprint, None] = None, + ssl: Union[SSLContext, Literal[False], Fingerprint, None] = None, proxy_headers: Optional[LooseHeaders] = None, traces: Optional[List["Trace"]] = None, + trust_env: bool = False, ): match = _CONTAINS_CONTROL_CHAR_RE.search(method) if match: @@ -251,7 +256,7 @@ def __init__( self.update_auto_headers(skip_auto_headers) self.update_cookies(cookies) self.update_content_encoding(data) - self.update_auth(auth) + self.update_auth(auth, trust_env) self.update_proxy(proxy, proxy_auth, proxy_headers) self.update_body_from_data(data) @@ -266,7 +271,7 @@ def is_ssl(self) -> bool: return self.url.scheme in ("https", "wss") @property - def ssl(self) -> Union["SSLContext", None, bool, Fingerprint]: + def ssl(self) -> Union["SSLContext", None, Literal[False], Fingerprint]: return self._ssl @property @@ -428,10 +433,14 @@ def update_transfer_encoding(self) -> None: if hdrs.CONTENT_LENGTH not in self.headers: self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) - def update_auth(self, auth: Optional[BasicAuth]) -> None: + def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None: """Set basic auth.""" if auth is None: auth = self.auth + if auth is None and trust_env and self.url.host is not None: + netrc_obj = netrc_from_env() + with contextlib.suppress(LookupError): + auth = basicauth_from_netrc(netrc_obj, self.url.host) if auth is None: return @@ -472,7 +481,7 @@ def update_body_from_data(self, body: Any) -> None: # copy payload headers assert body.headers - for (key, value) in body.headers.items(): + for key, value in body.headers.items(): if key in self.headers: continue if key in self.skip_auto_headers: @@ -551,6 +560,8 @@ async def write_bytes( protocol.set_exception(exc) except Exception as exc: protocol.set_exception(exc) + else: + protocol.start_timeout() finally: self._writer = None @@ -660,7 +671,6 @@ async def _on_headers_request_sent( class ClientResponse(HeadersMixin): - # Some of these attributes are None when created, # but will be set by the start() method. # As the end user will likely never see the None values, we cheat the types below. diff --git a/aiohttp/compression_utils.py b/aiohttp/compression_utils.py new file mode 100644 index 00000000000..8abc4fa7c3c --- /dev/null +++ b/aiohttp/compression_utils.py @@ -0,0 +1,148 @@ +import asyncio +import zlib +from concurrent.futures import Executor +from typing import Optional, cast + +try: + import brotli + + HAS_BROTLI = True +except ImportError: # pragma: no cover + HAS_BROTLI = False + +MAX_SYNC_CHUNK_SIZE = 1024 + + +def encoding_to_mode( + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, +) -> int: + if encoding == "gzip": + return 16 + zlib.MAX_WBITS + + return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS + + +class ZlibBaseHandler: + def __init__( + self, + mode: int, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + self._mode = mode + self._executor = executor + self._max_sync_chunk_size = max_sync_chunk_size + + +class ZLibCompressor(ZlibBaseHandler): + def __init__( + self, + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, + level: Optional[int] = None, + wbits: Optional[int] = None, + strategy: int = zlib.Z_DEFAULT_STRATEGY, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + super().__init__( + mode=encoding_to_mode(encoding, suppress_deflate_header) + if wbits is None + else wbits, + executor=executor, + max_sync_chunk_size=max_sync_chunk_size, + ) + if level is None: + self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy) + else: + self._compressor = zlib.compressobj( + wbits=self._mode, strategy=strategy, level=level + ) + + def compress_sync(self, data: bytes) -> bytes: + return self._compressor.compress(data) + + async def compress(self, data: bytes) -> bytes: + if ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.compress_sync, data + ) + return self.compress_sync(data) + + def flush(self, mode: int = zlib.Z_FINISH) -> bytes: + return self._compressor.flush(mode) + + +class ZLibDecompressor(ZlibBaseHandler): + def __init__( + self, + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + super().__init__( + mode=encoding_to_mode(encoding, suppress_deflate_header), + executor=executor, + max_sync_chunk_size=max_sync_chunk_size, + ) + self._decompressor = zlib.decompressobj(wbits=self._mode) + + def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes: + return self._decompressor.decompress(data, max_length) + + async def decompress(self, data: bytes, max_length: int = 0) -> bytes: + if ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.decompress_sync, data, max_length + ) + return self.decompress_sync(data, max_length) + + def flush(self, length: int = 0) -> bytes: + return ( + self._decompressor.flush(length) + if length > 0 + else self._decompressor.flush() + ) + + @property + def eof(self) -> bool: + return self._decompressor.eof + + @property + def unconsumed_tail(self) -> bytes: + return self._decompressor.unconsumed_tail + + @property + def unused_data(self) -> bytes: + return self._decompressor.unused_data + + +class BrotliDecompressor: + # Supports both 'brotlipy' and 'Brotli' packages + # since they share an import name. The top branches + # are for 'brotlipy' and bottom branches for 'Brotli' + def __init__(self) -> None: + if not HAS_BROTLI: + raise RuntimeError( + "The brotli decompression is not available. " + "Please install `Brotli` module" + ) + self._obj = brotli.Decompressor() + + def decompress_sync(self, data: bytes) -> bytes: + if hasattr(self._obj, "decompress"): + return cast(bytes, self._obj.decompress(data)) + return cast(bytes, self._obj.process(data)) + + def flush(self) -> bytes: + if hasattr(self._obj, "flush"): + return cast(bytes, self._obj.flush()) + return b"" diff --git a/aiohttp/connector.py b/aiohttp/connector.py index a0c9ea7482b..3d9f7b59f7b 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -22,6 +22,7 @@ Dict, Iterator, List, + Literal, Optional, Set, Tuple, @@ -69,7 +70,6 @@ class Connection: - _source_traceback = None _transport = None @@ -192,7 +192,6 @@ def __init__( enable_cleanup_closed: bool = False, timeout_ceil_threshold: float = 5, ) -> None: - if force_close: if keepalive_timeout is not None and keepalive_timeout is not sentinel: raise ValueError( @@ -466,7 +465,7 @@ def _available_connections(self, key: "ConnectionKey") -> int: return available async def connect( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> Connection: """Get from pool or create new connection.""" key = req.connection_key @@ -661,7 +660,7 @@ def _release( ) async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: raise NotImplementedError() @@ -736,7 +735,7 @@ def __init__( use_dns_cache: bool = True, ttl_dns_cache: Optional[int] = 10, family: int = 0, - ssl: Union[None, bool, Fingerprint, SSLContext] = None, + ssl: Union[None, Literal[False], Fingerprint, SSLContext] = None, local_addr: Optional[Tuple[str, int]] = None, resolver: Optional[AbstractResolver] = None, keepalive_timeout: Union[None, float, _SENTINEL] = sentinel, @@ -813,7 +812,6 @@ async def _resolve_host( ] if not self._use_dns_cache: - if traces: for trace in traces: await trace.send_dns_resolvehost_start(host) @@ -851,7 +849,6 @@ async def _resolve_host( for trace in traces: await trace.send_dns_cache_miss(host) try: - if traces: for trace in traces: await trace.send_dns_resolvehost_start(host) @@ -874,7 +871,7 @@ async def _resolve_host( return self._cached_hosts.next_addrs(key) async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: """Create connection. @@ -910,7 +907,7 @@ def _make_ssl_context(verified: bool) -> SSLContext: sslcontext.set_default_verify_paths() return sslcontext - def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]: + def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: """Logic to get the correct SSL context 0. if req.ssl is false, return None @@ -943,7 +940,7 @@ def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]: else: return None - def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]: + def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: ret = req.ssl if isinstance(ret, Fingerprint): return ret @@ -955,7 +952,7 @@ def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, - req: "ClientRequest", + req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, **kwargs: Any, @@ -977,7 +974,7 @@ async def _wrap_create_connection( def _warn_about_tls_in_tls( self, underlying_transport: asyncio.Transport, - req: "ClientRequest", + req: ClientRequest, ) -> None: """Issue a warning if the requested URL has HTTPS scheme.""" if req.request_info.url.scheme != "https": @@ -997,8 +994,8 @@ def _warn_about_tls_in_tls( "This support for TLS in TLS is known to be disabled " "in the stdlib asyncio. This is why you'll probably see " "an error in the log below.\n\n" - "It is possible to enable it via monkeypatching under " - "Python 3.7 or higher. For more details, see:\n" + "It is possible to enable it via monkeypatching. " + "For more details, see:\n" "* https://bugs.python.org/issue37179\n" "* https://github.com/python/cpython/pull/28073\n\n" "You can temporarily patch this as follows:\n" @@ -1014,7 +1011,7 @@ def _warn_about_tls_in_tls( async def _start_tls_connection( self, underlying_transport: asyncio.Transport, - req: "ClientRequest", + req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: @@ -1066,6 +1063,9 @@ async def _start_tls_connection( f"[{type_err!s}]" ) from type_err else: + if tls_transport is None: + msg = "Failed to start TLS (possibly caused by closing transport)" + raise client_error(req.connection_key, OSError(msg)) tls_proto.connection_made( tls_transport ) # Kick the state machine of the new TLS protocol @@ -1074,7 +1074,7 @@ async def _start_tls_connection( async def _create_direct_connection( self, - req: "ClientRequest", + req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout", *, @@ -1150,7 +1150,7 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: raise last_exc async def _create_proxy_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: headers: Dict[str, str] = {} if req.proxy_headers is not None: @@ -1289,7 +1289,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( @@ -1349,7 +1349,7 @@ def path(self) -> str: return self._path async def _create_connection( - self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout" + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" ) -> ResponseHandler: try: async with ceil_timeout( diff --git a/aiohttp/cookiejar.py b/aiohttp/cookiejar.py index 9a2bf192629..6891c0adf97 100644 --- a/aiohttp/cookiejar.py +++ b/aiohttp/cookiejar.py @@ -1,4 +1,3 @@ -import asyncio import contextlib import datetime import os # noqa @@ -55,7 +54,7 @@ class CookieJar(AbstractCookieJar): MAX_TIME = datetime.datetime.max.replace(tzinfo=datetime.timezone.utc) - MAX_32BIT_TIME = datetime.datetime.utcfromtimestamp(2**31 - 1) + MAX_32BIT_TIME = datetime.datetime.fromtimestamp(2**31 - 1, datetime.timezone.utc) def __init__( self, @@ -64,7 +63,6 @@ def __init__( quote_cookie: bool = True, treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None ) -> None: - self._loop = asyncio.get_running_loop() self._cookies: DefaultDict[Tuple[str, str], SimpleCookie[str]] = defaultdict( SimpleCookie ) @@ -342,7 +340,6 @@ def _parse_date(cls, date_str: str) -> Optional[datetime.datetime]: year = 0 for token_match in cls.DATE_TOKENS_RE.finditer(date_str): - token = token_match.group("token") if not found_time: diff --git a/aiohttp/formdata.py b/aiohttp/formdata.py index cd11755ac33..e159fb6fbcb 100644 --- a/aiohttp/formdata.py +++ b/aiohttp/formdata.py @@ -51,7 +51,6 @@ def add_field( filename: Optional[str] = None, content_transfer_encoding: Optional[str] = None, ) -> None: - if isinstance(value, io.IOBase): self._is_multipart = True elif isinstance(value, (bytes, bytearray, memoryview)): diff --git a/aiohttp/helpers.py b/aiohttp/helpers.py index 260fa163007..fa7b8a2ace7 100644 --- a/aiohttp/helpers.py +++ b/aiohttp/helpers.py @@ -3,6 +3,7 @@ import asyncio import base64 import binascii +import contextlib import dataclasses import datetime import enum @@ -41,7 +42,7 @@ Type, TypeVar, Union, - cast, + get_args, overload, ) from urllib.parse import quote @@ -56,14 +57,8 @@ from .log import client_logger from .typedefs import PathLike # noqa -if sys.version_info >= (3, 8): - from typing import get_args -else: - from typing_extensions import get_args - __all__ = ("BasicAuth", "ChainMapProxy", "ETag") -PY_38 = sys.version_info >= (3, 8) PY_310 = sys.version_info >= (3, 10) COOKIE_MAX_LENGTH = 4096 @@ -114,16 +109,6 @@ def __await__(self) -> Generator[None, None, None]: yield -if PY_38: - iscoroutinefunction = asyncio.iscoroutinefunction -else: - - def iscoroutinefunction(func: Any) -> bool: # type: ignore[misc] - while isinstance(func, functools.partial): - func = func.func - return asyncio.iscoroutinefunction(func) - - json_re = re.compile(r"(?:application/|[\w.-]+/[\w.+-]+?\+)json$", re.IGNORECASE) @@ -229,8 +214,11 @@ def netrc_from_env() -> Optional[netrc.netrc]: except netrc.NetrcParseError as e: client_logger.warning("Could not parse .netrc file: %s", e) except OSError as e: + netrc_exists = False + with contextlib.suppress(OSError): + netrc_exists = netrc_path.is_file() # we couldn't read the file (doesn't exist, permissions, etc.) - if netrc_env or netrc_path.is_file(): + if netrc_env or netrc_exists: # only warn if the environment wanted us to load it, # or it appears like the default file does actually exist client_logger.warning("Could not read .netrc file: %s", e) @@ -244,6 +232,35 @@ class ProxyInfo: proxy_auth: Optional[BasicAuth] +def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth: + """ + Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``. + + :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no + entry is found for the ``host``. + """ + if netrc_obj is None: + raise LookupError("No .netrc file found") + auth_from_netrc = netrc_obj.authenticators(host) + + if auth_from_netrc is None: + raise LookupError(f"No entry for {host!s} found in the `.netrc` file.") + login, account, password = auth_from_netrc + + # TODO(PY311): username = login or account + # Up to python 3.10, account could be None if not specified, + # and login will be empty string if not specified. From 3.11, + # login and account will be empty string if not specified. + username = login if (login or account is None) else account + + # TODO(PY311): Remove this, as password will be empty string + # if not specified + if password is None: + password = "" + + return BasicAuth(username, password) + + def proxies_from_env() -> Dict[str, ProxyInfo]: proxy_urls = { k: URL(v) @@ -261,16 +278,11 @@ def proxies_from_env() -> Dict[str, ProxyInfo]: ) continue if netrc_obj and auth is None: - auth_from_netrc = None if proxy.host is not None: - auth_from_netrc = netrc_obj.authenticators(proxy.host) - if auth_from_netrc is not None: - # auth_from_netrc is a (`user`, `account`, `password`) tuple, - # `user` and `account` both can be username, - # if `user` is None, use `account` - *logins, password = auth_from_netrc - login = logins[0] if logins[0] else logins[-1] - auth = BasicAuth(cast(str, login), cast(str, password)) + try: + auth = basicauth_from_netrc(netrc_obj, proxy.host) + except LookupError: + auth = None ret[proto] = ProxyInfo(proxy, auth) return ret @@ -653,7 +665,8 @@ def __call__(self) -> None: class BaseTimerContext(ContextManager["BaseTimerContext"]): - pass + def assert_timeout(self) -> None: + """Raise TimeoutError if timeout has been exceeded.""" class TimerNoop(BaseTimerContext): @@ -677,6 +690,11 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._tasks: List[asyncio.Task[Any]] = [] self._cancelled = False + def assert_timeout(self) -> None: + """Raise TimeoutError if timer has already been cancelled.""" + if self._cancelled: + raise asyncio.TimeoutError from None + def __enter__(self) -> BaseTimerContext: task = asyncio.current_task(loop=self._loop) @@ -727,7 +745,6 @@ def ceil_timeout( class HeadersMixin: - __slots__ = ("_content_type", "_content_dict", "_stored_content_type") def __init__(self) -> None: diff --git a/aiohttp/http.py b/aiohttp/http.py index a985334aef3..244d71c4197 100644 --- a/aiohttp/http.py +++ b/aiohttp/http.py @@ -1,34 +1,29 @@ import sys from . import __version__ -from .http_exceptions import HttpProcessingError as HttpProcessingError +from .http_exceptions import HttpProcessingError from .http_parser import ( - HeadersParser as HeadersParser, - HttpParser as HttpParser, - HttpRequestParser as HttpRequestParser, - HttpResponseParser as HttpResponseParser, - RawRequestMessage as RawRequestMessage, - RawResponseMessage as RawResponseMessage, + HeadersParser, + HttpParser, + HttpRequestParser, + HttpResponseParser, + RawRequestMessage, + RawResponseMessage, ) from .http_websocket import ( - WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE, - WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE, - WS_KEY as WS_KEY, - WebSocketError as WebSocketError, - WebSocketReader as WebSocketReader, - WebSocketWriter as WebSocketWriter, - WSCloseCode as WSCloseCode, - WSMessage as WSMessage, - WSMsgType as WSMsgType, - ws_ext_gen as ws_ext_gen, - ws_ext_parse as ws_ext_parse, -) -from .http_writer import ( - HttpVersion as HttpVersion, - HttpVersion10 as HttpVersion10, - HttpVersion11 as HttpVersion11, - StreamWriter as StreamWriter, + WS_CLOSED_MESSAGE, + WS_CLOSING_MESSAGE, + WS_KEY, + WebSocketError, + WebSocketReader, + WebSocketWriter, + WSCloseCode, + WSMessage, + WSMsgType, + ws_ext_gen, + ws_ext_parse, ) +from .http_writer import HttpVersion, HttpVersion10, HttpVersion11, StreamWriter __all__ = ( "HttpProcessingError", diff --git a/aiohttp/http_exceptions.py b/aiohttp/http_exceptions.py index c885f80f322..728824f856f 100644 --- a/aiohttp/http_exceptions.py +++ b/aiohttp/http_exceptions.py @@ -1,6 +1,7 @@ """Low-level http related exceptions.""" +from textwrap import indent from typing import Optional, Union from .typedefs import _CIMultiDict @@ -35,14 +36,14 @@ def __init__( self.message = message def __str__(self) -> str: - return f"{self.code}, message={self.message!r}" + msg = indent(self.message, " ") + return f"{self.code}, message:\n{msg}" def __repr__(self) -> str: - return f"<{self.__class__.__name__}: {self}>" + return f"<{self.__class__.__name__}: {self.code}, message={self.message!r}>" class BadHttpMessage(HttpProcessingError): - code = 400 message = "Bad Request" @@ -52,7 +53,6 @@ def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> N class HttpBadRequest(BadHttpMessage): - code = 400 message = "Bad Request" diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index fa214866f3c..e9da4350a95 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -1,13 +1,10 @@ import abc import asyncio -import collections import re import string -import zlib from contextlib import suppress from enum import IntEnum from typing import ( - Any, Generic, List, NamedTuple, @@ -18,7 +15,6 @@ Type, TypeVar, Union, - cast, ) from multidict import CIMultiDict, CIMultiDictProxy, istr @@ -27,6 +23,7 @@ from . import hdrs from .base_protocol import BaseProtocol +from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor from .helpers import NO_EXTENSIONS, BaseTimerContext from .http_exceptions import ( BadHttpMessage, @@ -42,14 +39,6 @@ from .streams import EMPTY_PAYLOAD, StreamReader from .typedefs import RawHeaders -try: - import brotli - - HAS_BROTLI = True -except ImportError: # pragma: no cover - HAS_BROTLI = False - - __all__ = ( "HeadersParser", "HttpParser", @@ -86,27 +75,22 @@ class RawRequestMessage(NamedTuple): url: URL -RawResponseMessage = collections.namedtuple( - "RawResponseMessage", - [ - "version", - "code", - "reason", - "headers", - "raw_headers", - "should_close", - "compression", - "upgrade", - "chunked", - ], -) +class RawResponseMessage(NamedTuple): + version: HttpVersion + code: int + reason: str + headers: CIMultiDictProxy[str] + raw_headers: RawHeaders + should_close: bool + compression: Optional[str] + upgrade: bool + chunked: bool _MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) class ParseState(IntEnum): - PARSE_NONE = 0 PARSE_LENGTH = 1 PARSE_CHUNKED = 2 @@ -280,7 +264,6 @@ def feed_data( METH_CONNECT: str = hdrs.METH_CONNECT, SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, ) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]: - messages = [] if self._tail: @@ -291,7 +274,6 @@ def feed_data( loop = self.loop while start_pos < data_len: - # read HTTP message (request/response line + headers), \r\n\r\n # and split by lines if self._payload_parser is None and not self._upgraded: @@ -759,7 +741,6 @@ def feed_data( self._chunk_tail = b"" while chunk: - # read next chunk size if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: pos = chunk.find(SEP) @@ -863,34 +844,16 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: self.encoding = encoding self._started_decoding = False + self.decompressor: Union[BrotliDecompressor, ZLibDecompressor] if encoding == "br": if not HAS_BROTLI: # pragma: no cover raise ContentEncodingError( "Can not decode content-encoding: brotli (br). " "Please install `Brotli`" ) - - class BrotliDecoder: - # Supports both 'brotlipy' and 'Brotli' packages - # since they share an import name. The top branches - # are for 'brotlipy' and bottom branches for 'Brotli' - def __init__(self) -> None: - self._obj = brotli.Decompressor() - - def decompress(self, data: bytes) -> bytes: - if hasattr(self._obj, "decompress"): - return cast(bytes, self._obj.decompress(data)) - return cast(bytes, self._obj.process(data)) - - def flush(self) -> bytes: - if hasattr(self._obj, "flush"): - return cast(bytes, self._obj.flush()) - return b"" - - self.decompressor: Any = BrotliDecoder() + self.decompressor = BrotliDecompressor() else: - zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS - self.decompressor = zlib.decompressobj(wbits=zlib_mode) + self.decompressor = ZLibDecompressor(encoding=encoding) def set_exception(self, exc: BaseException) -> None: self.out.set_exception(exc) @@ -911,10 +874,12 @@ def feed_data(self, chunk: bytes, size: int) -> None: ): # Change the decoder to decompress incorrectly compressed data # Actually we should issue a warning about non-RFC-compliant data. - self.decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS) + self.decompressor = ZLibDecompressor( + encoding=self.encoding, suppress_deflate_header=True + ) try: - chunk = self.decompressor.decompress(chunk) + chunk = self.decompressor.decompress_sync(chunk) except Exception: raise ContentEncodingError( "Can not decode content-encoding: %s" % self.encoding @@ -930,7 +895,7 @@ def feed_eof(self) -> None: if chunk or self.size > 0: self.out.feed_data(chunk, len(chunk)) - if self.encoding == "deflate" and not self.decompressor.eof: + if self.encoding == "deflate" and not self.decompressor.eof: # type: ignore raise ContentEncodingError("deflate") self.out.feed_eof() diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index fe5058cae62..f2e348d651b 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -1,7 +1,6 @@ """WebSocket protocol versions 13 and 8.""" import asyncio -import collections import functools import json import random @@ -10,11 +9,23 @@ import zlib from enum import IntEnum from struct import Struct -from typing import Any, Callable, List, Optional, Pattern, Set, Tuple, Union, cast +from typing import ( + Any, + Callable, + List, + NamedTuple, + Optional, + Pattern, + Set, + Tuple, + Union, + cast, +) from typing_extensions import Final from .base_protocol import BaseProtocol +from .compression_utils import ZLibCompressor, ZLibDecompressor from .helpers import NO_EXTENSIONS from .streams import DataQueue @@ -79,10 +90,12 @@ class WSMsgType(IntEnum): DEFAULT_LIMIT: Final[int] = 2**16 -_WSMessageBase = collections.namedtuple("_WSMessageBase", ["type", "data", "extra"]) - +class WSMessage(NamedTuple): + type: WSMsgType + # To type correctly, this would need some kind of tagged union for each type. + data: Any + extra: Optional[str] -class WSMessage(_WSMessageBase): def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any: """Return parsed JSON data. @@ -270,7 +283,7 @@ def __init__( self._payload_length = 0 self._payload_length_flag = 0 self._compressed: Optional[bool] = None - self._decompressobj: Any = None # zlib.decompressobj actually + self._decompressobj: Optional[ZLibDecompressor] = None self._compress = compress def feed_eof(self) -> None: @@ -290,7 +303,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]: def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: for fin, opcode, payload, compressed in self.parse_frame(data): if compressed and not self._decompressobj: - self._decompressobj = zlib.decompressobj(wbits=-zlib.MAX_WBITS) + self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) if opcode == WSMsgType.CLOSE: if len(payload) >= 2: close_code = UNPACK_CLOSE_CODE(payload[:2])[0] @@ -375,8 +388,9 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: # Decompress process must to be done after all packets # received. if compressed: + assert self._decompressobj is not None self._partial.extend(_WS_DEFLATE_TRAILING) - payload_merged = self._decompressobj.decompress( + payload_merged = self._decompressobj.decompress_sync( self._partial, self._max_msg_size ) if self._decompressobj.unconsumed_tail: @@ -604,16 +618,16 @@ async def _send_frame( if (compress or self.compress) and opcode < 8: if compress: # Do not set self._compress if compressing is for this frame - compressobj = zlib.compressobj(level=zlib.Z_BEST_SPEED, wbits=-compress) + compressobj = ZLibCompressor(level=zlib.Z_BEST_SPEED, wbits=-compress) else: # self.compress if not self._compressobj: - self._compressobj = zlib.compressobj( + self._compressobj = ZLibCompressor( level=zlib.Z_BEST_SPEED, wbits=-self.compress ) compressobj = self._compressobj - message = compressobj.compress(message) - message = message + compressobj.flush( + message = await compressobj.compress(message) + message += compressobj.flush( zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH ) if message.endswith(_WS_DEFLATE_TRAILING): @@ -659,13 +673,13 @@ def _write(self, data: bytes) -> None: raise ConnectionResetError("Cannot write to closing transport") self.transport.write(data) - async def pong(self, message: bytes = b"") -> None: + async def pong(self, message: Union[bytes, str] = b"") -> None: """Send pong message.""" if isinstance(message, str): message = message.encode("utf-8") await self._send_frame(message, WSMsgType.PONG) - async def ping(self, message: bytes = b"") -> None: + async def ping(self, message: Union[bytes, str] = b"") -> None: """Send ping message.""" if isinstance(message, str): message = message.encode("utf-8") @@ -685,7 +699,7 @@ async def send( else: await self._send_frame(message, WSMsgType.TEXT, compress) - async def close(self, code: int = 1000, message: bytes = b"") -> None: + async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None: """Close the websocket, sending the specified code and message.""" if isinstance(message, str): message = message.encode("utf-8") diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index db3d6a04897..8f2d9086b92 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -8,6 +8,7 @@ from .abc import AbstractStreamWriter from .base_protocol import BaseProtocol +from .compression_utils import ZLibCompressor from .helpers import NO_EXTENSIONS __all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11") @@ -35,7 +36,6 @@ def __init__( on_headers_sent: _T_OnHeadersSent = None, ) -> None: self._protocol = protocol - self._transport = protocol.transport self.loop = loop self.length = None @@ -44,7 +44,7 @@ def __init__( self.output_size = 0 self._eof = False - self._compress: Any = None + self._compress: Optional[ZLibCompressor] = None self._drain_waiter = None self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent @@ -52,7 +52,7 @@ def __init__( @property def transport(self) -> Optional[asyncio.Transport]: - return self._transport + return self._protocol.transport @property def protocol(self) -> BaseProtocol: @@ -64,17 +64,16 @@ def enable_chunking(self) -> None: def enable_compression( self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY ) -> None: - zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS - self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy) + self._compress = ZLibCompressor(encoding=encoding, strategy=strategy) def _write(self, chunk: bytes) -> None: size = len(chunk) self.buffer_size += size self.output_size += size - - if self._transport is None or self._transport.is_closing(): + transport = self.transport + if not self._protocol.connected or transport is None or transport.is_closing(): raise ConnectionResetError("Cannot write to closing transport") - self._transport.write(chunk) + transport.write(chunk) async def write( self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000 @@ -94,7 +93,7 @@ async def write( chunk = chunk.cast("c") if self._compress is not None: - chunk = self._compress.compress(chunk) + chunk = await self._compress.compress(chunk) if not chunk: return @@ -139,9 +138,9 @@ async def write_eof(self, chunk: bytes = b"") -> None: if self._compress: if chunk: - chunk = self._compress.compress(chunk) + chunk = await self._compress.compress(chunk) - chunk = chunk + self._compress.flush() + chunk += self._compress.flush() if chunk and self.chunked: chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" @@ -159,7 +158,6 @@ async def write_eof(self, chunk: bytes = b"") -> None: await self.drain() self._eof = True - self._transport = None async def drain(self) -> None: """Flush the write buffer. diff --git a/aiohttp/multipart.py b/aiohttp/multipart.py index 942dda507ab..0eecb48ddfc 100644 --- a/aiohttp/multipart.py +++ b/aiohttp/multipart.py @@ -27,6 +27,7 @@ from multidict import CIMultiDict, CIMultiDictProxy, MultiMapping +from .compression_utils import ZLibCompressor, ZLibDecompressor from .hdrs import ( CONTENT_DISPOSITION, CONTENT_ENCODING, @@ -491,15 +492,15 @@ def decode(self, data: bytes) -> bytes: def _decode_content(self, data: bytes) -> bytes: encoding = self.headers.get(CONTENT_ENCODING, "").lower() - - if encoding == "deflate": - return zlib.decompress(data, -zlib.MAX_WBITS) - elif encoding == "gzip": - return zlib.decompress(data, 16 + zlib.MAX_WBITS) - elif encoding == "identity": + if encoding == "identity": return data - else: - raise RuntimeError(f"unknown content encoding: {encoding}") + if encoding in {"deflate", "gzip"}: + return ZLibDecompressor( + encoding=encoding, + suppress_deflate_header=True, + ).decompress_sync(data) + + raise RuntimeError(f"unknown content encoding: {encoding}") def _decode_content_transfer(self, data: bytes) -> bytes: encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() @@ -976,7 +977,7 @@ class MultipartPayloadWriter: def __init__(self, writer: Any) -> None: self._writer = writer self._encoding: Optional[str] = None - self._compress: Any = None + self._compress: Optional[ZLibCompressor] = None self._encoding_buffer: Optional[bytearray] = None def enable_encoding(self, encoding: str) -> None: @@ -989,8 +990,11 @@ def enable_encoding(self, encoding: str) -> None: def enable_compression( self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY ) -> None: - zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else -zlib.MAX_WBITS - self._compress = zlib.compressobj(wbits=zlib_mode, strategy=strategy) + self._compress = ZLibCompressor( + encoding=encoding, + suppress_deflate_header=True, + strategy=strategy, + ) async def write_eof(self) -> None: if self._compress is not None: @@ -1006,7 +1010,7 @@ async def write_eof(self) -> None: async def write(self, chunk: bytes) -> None: if self._compress is not None: if chunk: - chunk = self._compress.compress(chunk) + chunk = await self._compress.compress(chunk) if not chunk: return diff --git a/aiohttp/payload.py b/aiohttp/payload.py index 92db73b4379..11f3f2db7b0 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -132,7 +132,6 @@ def register( class Payload(ABC): - _default_content_type: str = "application/octet-stream" _size: Optional[int] = None @@ -254,7 +253,6 @@ def __init__( content_type: Optional[str] = None, **kwargs: Any, ) -> None: - if encoding is None: if content_type is None: real_encoding = "utf-8" @@ -318,7 +316,6 @@ def __init__( content_type: Optional[str] = None, **kwargs: Any, ) -> None: - if encoding is None: if content_type is None: encoding = "utf-8" @@ -391,7 +388,6 @@ def __init__( *args: Any, **kwargs: Any, ) -> None: - super().__init__( dumps(value).encode(encoding), content_type=content_type, @@ -414,14 +410,13 @@ def __init__( class AsyncIterablePayload(Payload): - _iter: Optional[_AsyncIterator] = None def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: if not isinstance(value, AsyncIterable): raise TypeError( "value argument must support " - "collections.abc.AsyncIterablebe interface, " + "collections.abc.AsyncIterable interface, " "got {!r}".format(type(value)) ) diff --git a/aiohttp/pytest_plugin.py b/aiohttp/pytest_plugin.py index d08327e0311..8bbe46f559a 100644 --- a/aiohttp/pytest_plugin.py +++ b/aiohttp/pytest_plugin.py @@ -2,7 +2,7 @@ import contextlib import inspect import warnings -from typing import Any, Awaitable, Callable, Dict, Generator, Optional, Type, Union +from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union import pytest @@ -22,14 +22,11 @@ try: import uvloop except ImportError: # pragma: no cover - uvloop = None - -try: - import tokio -except ImportError: # pragma: no cover - tokio = None + uvloop = None # type: ignore[assignment] AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]] +AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]] +AiohttpServer = Callable[[Application], Awaitable[TestServer]] def pytest_addoption(parser): # type: ignore[no-untyped-def] @@ -43,7 +40,7 @@ def pytest_addoption(parser): # type: ignore[no-untyped-def] "--aiohttp-loop", action="store", default="pyloop", - help="run tests with specific loop: pyloop, uvloop, tokio or all", + help="run tests with specific loop: pyloop, uvloop or all", ) parser.addoption( "--aiohttp-enable-loop-debug", @@ -198,16 +195,14 @@ def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def] return loops = metafunc.config.option.aiohttp_loop + avail_factories: Dict[str, Type[asyncio.AbstractEventLoopPolicy]] avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy} if uvloop is not None: # pragma: no cover avail_factories["uvloop"] = uvloop.EventLoopPolicy - if tokio is not None: # pragma: no cover - avail_factories["tokio"] = tokio.EventLoopPolicy - if loops == "all": - loops = "pyloop,uvloop?,tokio?" + loops = "pyloop,uvloop?" factories = {} # type: ignore[var-annotated] for name in loops.split(","): @@ -250,13 +245,13 @@ def proactor_loop(): # type: ignore[no-untyped-def] @pytest.fixture -def aiohttp_unused_port(): # type: ignore[no-untyped-def] +def aiohttp_unused_port() -> Callable[[], int]: """Return a port that is unused on the current host.""" return _unused_port @pytest.fixture -def aiohttp_server(loop): # type: ignore[no-untyped-def] +def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]: """Factory to create a TestServer instance, given an app. aiohttp_server(app, **kwargs) @@ -279,7 +274,7 @@ async def finalize() -> None: @pytest.fixture -def aiohttp_raw_server(loop): # type: ignore[no-untyped-def] +def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]: """Factory to create a RawTestServer instance, given a web handler. aiohttp_raw_server(handler, **kwargs) @@ -331,7 +326,7 @@ def test_login(aiohttp_client): @pytest.fixture def aiohttp_client( loop: asyncio.AbstractEventLoop, aiohttp_client_cls: Type[TestClient] -) -> Generator[AiohttpClient, None, None]: +) -> Iterator[AiohttpClient]: """Factory to create a TestClient instance. aiohttp_client(app, **kwargs) diff --git a/aiohttp/streams.py b/aiohttp/streams.py index 259609209ae..1a2c5147fc3 100644 --- a/aiohttp/streams.py +++ b/aiohttp/streams.py @@ -6,7 +6,7 @@ from typing_extensions import Final from .base_protocol import BaseProtocol -from .helpers import BaseTimerContext, set_exception, set_result +from .helpers import BaseTimerContext, TimerNoop, set_exception, set_result from .log import internal_logger try: # pragma: no cover @@ -122,7 +122,7 @@ def __init__( self._waiter: Optional[asyncio.Future[None]] = None self._eof_waiter: Optional[asyncio.Future[None]] = None self._exception: Optional[BaseException] = None - self._timer = timer + self._timer = TimerNoop() if timer is None else timer self._eof_callbacks: List[Callable[[], None]] = [] def __repr__(self) -> str: @@ -297,10 +297,7 @@ async def _wait(self, func_name: str) -> None: waiter = self._waiter = self._loop.create_future() try: - if self._timer: - with self._timer: - await waiter - else: + with self._timer: await waiter finally: self._waiter = None @@ -477,8 +474,9 @@ def _read_nowait_chunk(self, n: int) -> bytes: def _read_nowait(self, n: int) -> bytes: """Read not more than n bytes, or whole buffer if n == -1""" - chunks = [] + self._timer.assert_timeout() + chunks = [] while self._buffer: chunk = self._read_nowait_chunk(n) chunks.append(chunk) diff --git a/aiohttp/test_utils.py b/aiohttp/test_utils.py index e0aacbb68fc..4d1e9a7f37c 100644 --- a/aiohttp/test_utils.py +++ b/aiohttp/test_utils.py @@ -21,7 +21,7 @@ Union, cast, ) -from unittest import mock +from unittest import IsolatedAsyncioTestCase, mock from aiosignal import Signal from multidict import CIMultiDict, CIMultiDictProxy @@ -34,8 +34,9 @@ from .abc import AbstractCookieJar from .client_reqrep import ClientResponse from .client_ws import ClientWebSocketResponse -from .helpers import _SENTINEL, PY_38, sentinel +from .helpers import _SENTINEL, sentinel from .http import HttpVersion, RawRequestMessage +from .typedefs import StrOrURL from .web import ( Application, AppRunner, @@ -53,11 +54,6 @@ else: SSLContext = None -if PY_38: - from unittest import IsolatedAsyncioTestCase as TestCase -else: - from asynctest import TestCase # type: ignore[no-redef] - REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" @@ -115,7 +111,7 @@ async def start_server(self, **kwargs: Any) -> None: if self.runner: return self._ssl = kwargs.pop("ssl", None) - self.runner = await self._make_runner(**kwargs) + self.runner = await self._make_runner(handler_cancellation=True, **kwargs) await self.runner.setup() if not self.port: self.port = 0 @@ -148,14 +144,14 @@ async def start_server(self, **kwargs: Any) -> None: async def _make_runner(self, **kwargs: Any) -> BaseRunner: pass - def make_url(self, path: str) -> URL: + def make_url(self, path: StrOrURL) -> URL: assert self._root is not None url = URL(path) if not self.skip_url_asserts: assert not url.is_absolute() return self._root.join(url) else: - return URL(str(self._root) + path) + return URL(str(self._root) + str(path)) @property def started(self) -> bool: @@ -304,16 +300,20 @@ def session(self) -> ClientSession: """ return self._session - def make_url(self, path: str) -> URL: + def make_url(self, path: StrOrURL) -> URL: return self._server.make_url(path) - async def _request(self, method: str, path: str, **kwargs: Any) -> ClientResponse: + async def _request( + self, method: str, path: StrOrURL, **kwargs: Any + ) -> ClientResponse: resp = await self._session.request(method, self.make_url(path), **kwargs) # save it to close later self._responses.append(resp) return resp - def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManager: + def request( + self, method: str, path: StrOrURL, **kwargs: Any + ) -> _RequestContextManager: """Routes a request to tested http server. The interface is identical to aiohttp.ClientSession.request, @@ -323,35 +323,35 @@ def request(self, method: str, path: str, **kwargs: Any) -> _RequestContextManag """ return _RequestContextManager(self._request(method, path, **kwargs)) - def get(self, path: str, **kwargs: Any) -> _RequestContextManager: + def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP GET request.""" return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) - def post(self, path: str, **kwargs: Any) -> _RequestContextManager: + def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP POST request.""" return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) - def options(self, path: str, **kwargs: Any) -> _RequestContextManager: + def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP OPTIONS request.""" return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs)) - def head(self, path: str, **kwargs: Any) -> _RequestContextManager: + def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP HEAD request.""" return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) - def put(self, path: str, **kwargs: Any) -> _RequestContextManager: + def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP PUT request.""" return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) - def patch(self, path: str, **kwargs: Any) -> _RequestContextManager: + def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP PATCH request.""" return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs)) - def delete(self, path: str, **kwargs: Any) -> _RequestContextManager: + def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: """Perform an HTTP PATCH request.""" return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs)) - def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager: + def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: """Initiate websocket connection. The api corresponds to aiohttp.ClientSession.ws_connect. @@ -359,7 +359,9 @@ def ws_connect(self, path: str, **kwargs: Any) -> _WSRequestContextManager: """ return _WSRequestContextManager(self._ws_connect(path, **kwargs)) - async def _ws_connect(self, path: str, **kwargs: Any) -> ClientWebSocketResponse: + async def _ws_connect( + self, path: StrOrURL, **kwargs: Any + ) -> ClientWebSocketResponse: ws = await self._session.ws_connect(self.make_url(path), **kwargs) self._websockets.append(ws) return ws @@ -398,14 +400,12 @@ async def __aexit__( await self.close() -class AioHTTPTestCase(TestCase): +class AioHTTPTestCase(IsolatedAsyncioTestCase, ABC): """A base class to allow for unittest web applications using aiohttp. Provides the following: * self.client (aiohttp.test_utils.TestClient): an aiohttp test client. - * self.loop (asyncio.BaseEventLoop): the event loop in which the - application and server are running. * self.app (aiohttp.web.Application): the application returned by self.get_application() @@ -413,45 +413,22 @@ class AioHTTPTestCase(TestCase): execute function on the test client using asynchronous methods. """ + @abstractmethod async def get_application(self) -> Application: """Get application. - This method should be overridden - to return the aiohttp.web.Application + This method should be overridden to return the aiohttp.web.Application object to test. """ - return self.get_app() - - def get_app(self) -> Application: - """Obsolete method used to constructing web application. - - Use .get_application() coroutine instead. - """ - raise RuntimeError("Did you forget to define get_application()?") - - def setUp(self) -> None: - if not PY_38: - asyncio.get_event_loop().run_until_complete(self.asyncSetUp()) async def asyncSetUp(self) -> None: - self.loop = asyncio.get_running_loop() - return await self.setUpAsync() - - async def setUpAsync(self) -> None: self.app = await self.get_application() self.server = await self.get_server(self.app) self.client = await self.get_client(self.server) await self.client.start_server() - def tearDown(self) -> None: - if not PY_38: - self.loop.run_until_complete(self.asyncTearDown()) - async def asyncTearDown(self) -> None: - return await self.tearDownAsync() - - async def tearDownAsync(self) -> None: await self.client.close() async def get_server(self, app: Application) -> TestServer: @@ -497,16 +474,7 @@ def setup_test_loop( asyncio.set_event_loop(loop) if sys.platform != "win32" and not skip_watcher: policy = asyncio.get_event_loop_policy() - watcher: asyncio.AbstractChildWatcher - try: # Python >= 3.8 - # Refs: - # * https://github.com/pytest-dev/pytest-xdist/issues/620 - # * https://stackoverflow.com/a/58614689/595220 - # * https://bugs.python.org/issue35621 - # * https://github.com/python/cpython/pull/14344 - watcher = asyncio.ThreadedChildWatcher() - except AttributeError: # Python < 3.8 - watcher = asyncio.SafeChildWatcher() + watcher = asyncio.ThreadedChildWatcher() watcher.attach_loop(loop) with contextlib.suppress(NotImplementedError): policy.set_child_watcher(watcher) diff --git a/aiohttp/web.py b/aiohttp/web.py index e3e75779c6f..f87d57988be 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -20,137 +20,118 @@ ) from .abc import AbstractAccessLogger -from .helpers import AppKey as AppKey +from .helpers import AppKey from .log import access_logger from .typedefs import PathLike -from .web_app import Application as Application, CleanupError as CleanupError +from .web_app import Application, CleanupError from .web_exceptions import ( - HTTPAccepted as HTTPAccepted, - HTTPBadGateway as HTTPBadGateway, - HTTPBadRequest as HTTPBadRequest, - HTTPClientError as HTTPClientError, - HTTPConflict as HTTPConflict, - HTTPCreated as HTTPCreated, - HTTPError as HTTPError, - HTTPException as HTTPException, - HTTPExpectationFailed as HTTPExpectationFailed, - HTTPFailedDependency as HTTPFailedDependency, - HTTPForbidden as HTTPForbidden, - HTTPFound as HTTPFound, - HTTPGatewayTimeout as HTTPGatewayTimeout, - HTTPGone as HTTPGone, - HTTPInsufficientStorage as HTTPInsufficientStorage, - HTTPInternalServerError as HTTPInternalServerError, - HTTPLengthRequired as HTTPLengthRequired, - HTTPMethodNotAllowed as HTTPMethodNotAllowed, - HTTPMisdirectedRequest as HTTPMisdirectedRequest, - HTTPMovedPermanently as HTTPMovedPermanently, - HTTPMultipleChoices as HTTPMultipleChoices, - HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired, - HTTPNoContent as HTTPNoContent, - HTTPNonAuthoritativeInformation as HTTPNonAuthoritativeInformation, - HTTPNotAcceptable as HTTPNotAcceptable, - HTTPNotExtended as HTTPNotExtended, - HTTPNotFound as HTTPNotFound, - HTTPNotImplemented as HTTPNotImplemented, - HTTPNotModified as HTTPNotModified, - HTTPOk as HTTPOk, - HTTPPartialContent as HTTPPartialContent, - HTTPPaymentRequired as HTTPPaymentRequired, - HTTPPermanentRedirect as HTTPPermanentRedirect, - HTTPPreconditionFailed as HTTPPreconditionFailed, - HTTPPreconditionRequired as HTTPPreconditionRequired, - HTTPProxyAuthenticationRequired as HTTPProxyAuthenticationRequired, - HTTPRedirection as HTTPRedirection, - HTTPRequestEntityTooLarge as HTTPRequestEntityTooLarge, - HTTPRequestHeaderFieldsTooLarge as HTTPRequestHeaderFieldsTooLarge, - HTTPRequestRangeNotSatisfiable as HTTPRequestRangeNotSatisfiable, - HTTPRequestTimeout as HTTPRequestTimeout, - HTTPRequestURITooLong as HTTPRequestURITooLong, - HTTPResetContent as HTTPResetContent, - HTTPSeeOther as HTTPSeeOther, - HTTPServerError as HTTPServerError, - HTTPServiceUnavailable as HTTPServiceUnavailable, - HTTPSuccessful as HTTPSuccessful, - HTTPTemporaryRedirect as HTTPTemporaryRedirect, - HTTPTooManyRequests as HTTPTooManyRequests, - HTTPUnauthorized as HTTPUnauthorized, - HTTPUnavailableForLegalReasons as HTTPUnavailableForLegalReasons, - HTTPUnprocessableEntity as HTTPUnprocessableEntity, - HTTPUnsupportedMediaType as HTTPUnsupportedMediaType, - HTTPUpgradeRequired as HTTPUpgradeRequired, - HTTPUseProxy as HTTPUseProxy, - HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates, - HTTPVersionNotSupported as HTTPVersionNotSupported, + HTTPAccepted, + HTTPBadGateway, + HTTPBadRequest, + HTTPClientError, + HTTPConflict, + HTTPCreated, + HTTPError, + HTTPException, + HTTPExpectationFailed, + HTTPFailedDependency, + HTTPForbidden, + HTTPFound, + HTTPGatewayTimeout, + HTTPGone, + HTTPInsufficientStorage, + HTTPInternalServerError, + HTTPLengthRequired, + HTTPMethodNotAllowed, + HTTPMisdirectedRequest, + HTTPMove, + HTTPMovedPermanently, + HTTPMultipleChoices, + HTTPNetworkAuthenticationRequired, + HTTPNoContent, + HTTPNonAuthoritativeInformation, + HTTPNotAcceptable, + HTTPNotExtended, + HTTPNotFound, + HTTPNotImplemented, + HTTPNotModified, + HTTPOk, + HTTPPartialContent, + HTTPPaymentRequired, + HTTPPermanentRedirect, + HTTPPreconditionFailed, + HTTPPreconditionRequired, + HTTPProxyAuthenticationRequired, + HTTPRedirection, + HTTPRequestEntityTooLarge, + HTTPRequestHeaderFieldsTooLarge, + HTTPRequestRangeNotSatisfiable, + HTTPRequestTimeout, + HTTPRequestURITooLong, + HTTPResetContent, + HTTPSeeOther, + HTTPServerError, + HTTPServiceUnavailable, + HTTPSuccessful, + HTTPTemporaryRedirect, + HTTPTooManyRequests, + HTTPUnauthorized, + HTTPUnavailableForLegalReasons, + HTTPUnprocessableEntity, + HTTPUnsupportedMediaType, + HTTPUpgradeRequired, + HTTPUseProxy, + HTTPVariantAlsoNegotiates, + HTTPVersionNotSupported, ) -from .web_fileresponse import FileResponse as FileResponse +from .web_fileresponse import FileResponse from .web_log import AccessLogger -from .web_middlewares import ( - middleware as middleware, - normalize_path_middleware as normalize_path_middleware, -) -from .web_protocol import ( - PayloadAccessError as PayloadAccessError, - RequestHandler as RequestHandler, - RequestPayloadError as RequestPayloadError, -) -from .web_request import ( - BaseRequest as BaseRequest, - FileField as FileField, - Request as Request, -) -from .web_response import ( - ContentCoding as ContentCoding, - Response as Response, - StreamResponse as StreamResponse, - json_response as json_response, -) +from .web_middlewares import middleware, normalize_path_middleware +from .web_protocol import PayloadAccessError, RequestHandler, RequestPayloadError +from .web_request import BaseRequest, FileField, Request +from .web_response import ContentCoding, Response, StreamResponse, json_response from .web_routedef import ( - AbstractRouteDef as AbstractRouteDef, - RouteDef as RouteDef, - RouteTableDef as RouteTableDef, - StaticDef as StaticDef, - delete as delete, - get as get, - head as head, - options as options, - patch as patch, - post as post, - put as put, - route as route, - static as static, - view as view, + AbstractRouteDef, + RouteDef, + RouteTableDef, + StaticDef, + delete, + get, + head, + options, + patch, + post, + put, + route, + static, + view, ) from .web_runner import ( - AppRunner as AppRunner, - BaseRunner as BaseRunner, - BaseSite as BaseSite, - GracefulExit as GracefulExit, - NamedPipeSite as NamedPipeSite, - ServerRunner as ServerRunner, - SockSite as SockSite, - TCPSite as TCPSite, - UnixSite as UnixSite, + AppRunner, + BaseRunner, + BaseSite, + GracefulExit, + NamedPipeSite, + ServerRunner, + SockSite, + TCPSite, + UnixSite, ) -from .web_server import Server as Server +from .web_server import Server from .web_urldispatcher import ( - AbstractResource as AbstractResource, - AbstractRoute as AbstractRoute, - DynamicResource as DynamicResource, - PlainResource as PlainResource, - PrefixedSubAppResource as PrefixedSubAppResource, - Resource as Resource, - ResourceRoute as ResourceRoute, - StaticResource as StaticResource, - UrlDispatcher as UrlDispatcher, - UrlMappingMatchInfo as UrlMappingMatchInfo, - View as View, -) -from .web_ws import ( - WebSocketReady as WebSocketReady, - WebSocketResponse as WebSocketResponse, - WSMsgType as WSMsgType, + AbstractResource, + AbstractRoute, + DynamicResource, + PlainResource, + PrefixedSubAppResource, + Resource, + ResourceRoute, + StaticResource, + UrlDispatcher, + UrlMappingMatchInfo, + View, ) +from .web_ws import WebSocketReady, WebSocketResponse, WSMsgType __all__ = ( # web_app @@ -177,6 +158,7 @@ "HTTPLengthRequired", "HTTPMethodNotAllowed", "HTTPMisdirectedRequest", + "HTTPMove", "HTTPMovedPermanently", "HTTPMultipleChoices", "HTTPNetworkAuthenticationRequired", @@ -307,6 +289,7 @@ async def _run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, ) -> None: # An internal function to actually do all dirty job for application running if asyncio.iscoroutine(app): @@ -321,6 +304,7 @@ async def _run_app( access_log_format=access_log_format, access_log=access_log, keepalive_timeout=keepalive_timeout, + handler_cancellation=handler_cancellation, ) await runner.setup() @@ -425,15 +409,8 @@ async def _run_app( ) # sleep forever by 1 hour intervals, - # on Windows before Python 3.8 wake up every 1 second to handle - # Ctrl+C smoothly - if sys.platform == "win32" and sys.version_info < (3, 8): - delay = 1 - else: - delay = 3600 - while True: - await asyncio.sleep(delay) + await asyncio.sleep(3600) finally: await runner.cleanup() @@ -481,6 +458,7 @@ def run_app( handle_signals: bool = True, reuse_address: Optional[bool] = None, reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" @@ -513,6 +491,7 @@ def run_app( handle_signals=handle_signals, reuse_address=reuse_address, reuse_port=reuse_port, + handler_cancellation=handler_cancellation, ) ) diff --git a/aiohttp/web_app.py b/aiohttp/web_app.py index 438b1049eeb..80956831acc 100644 --- a/aiohttp/web_app.py +++ b/aiohttp/web_app.py @@ -100,7 +100,6 @@ def __init__( client_max_size: int = 1024**2, debug: Any = ..., # mypy doesn't support ellipsis ) -> None: - if debug is not ...: warnings.warn( "debug argument is no-op since 4.0 " "and scheduled for removal in 5.0", diff --git a/aiohttp/web_exceptions.py b/aiohttp/web_exceptions.py index 17383c2eb78..332ca9fa565 100644 --- a/aiohttp/web_exceptions.py +++ b/aiohttp/web_exceptions.py @@ -77,7 +77,6 @@ class HTTPException(CookieMixin, Exception): - # You should set in subclasses: # status = 200 diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index cdb0fe58923..ce85eeb6a69 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -283,4 +283,4 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter try: return await self._sendfile(request, fobj, offset, count) finally: - await loop.run_in_executor(None, fobj.close) + await asyncio.shield(loop.run_in_executor(None, fobj.close)) diff --git a/aiohttp/web_log.py b/aiohttp/web_log.py index 0fd862e6a84..633e9e3ae6b 100644 --- a/aiohttp/web_log.py +++ b/aiohttp/web_log.py @@ -189,6 +189,9 @@ def _format_line( return [(key, method(request, response, time)) for key, method in self._methods] def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None: + if not self.logger.isEnabledFor(logging.INFO): + # Avoid formatting the log line if it will not be emitted. + return try: fmt_info = self._format_line(request, response, time) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 45b6f423fc1..27c815a4461 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -313,6 +313,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: super().connection_lost(exc) + # Grab value before setting _manager to None. + handler_cancellation = self._manager.handler_cancellation + self._manager = None self._force_close = True self._request_factory = None @@ -330,6 +333,9 @@ def connection_lost(self, exc: Optional[BaseException]) -> None: if self._waiter is not None: self._waiter.cancel() + if handler_cancellation and self._task_handler is not None: + self._task_handler.cancel() + self._task_handler = None if self._payload_parser is not None: diff --git a/aiohttp/web_request.py b/aiohttp/web_request.py index 57838c2ee10..ae74e265f86 100644 --- a/aiohttp/web_request.py +++ b/aiohttp/web_request.py @@ -114,7 +114,6 @@ class FileField: class BaseRequest(MutableMapping[str, Any], HeadersMixin): - POST_METHODS = { hdrs.METH_PATCH, hdrs.METH_POST, @@ -846,7 +845,6 @@ async def wait_for_disconnection(self) -> None: class Request(BaseRequest): - __slots__ = ("_match_info",) def __init__(self, *args: Any, **kwargs: Any) -> None: diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 7788ce12d13..1e47454e997 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -6,10 +6,8 @@ import math import time import warnings -import zlib from concurrent.futures import Executor from http import HTTPStatus -from http.cookies import Morsel from typing import ( TYPE_CHECKING, Any, @@ -25,9 +23,9 @@ from . import hdrs, payload from .abc import AbstractStreamWriter +from .compression_utils import ZLibCompressor from .helpers import ( ETAG_ANY, - PY_38, QUOTED_ETAG_RE, CookieMixin, ETag, @@ -53,12 +51,6 @@ BaseClass = collections.abc.MutableMapping -if not PY_38: - # allow samesite to be used in python < 3.8 - # already permitted in python 3.8, see https://bugs.python.org/issue29613 - Morsel._reserved["samesite"] = "SameSite" # type: ignore[attr-defined] - - class ContentCoding(enum.Enum): # The content codings that we have support for. # @@ -75,7 +67,6 @@ class ContentCoding(enum.Enum): class StreamResponse(BaseClass, HeadersMixin, CookieMixin): - __slots__ = ( "_length_check", "_body", @@ -495,7 +486,6 @@ def __eq__(self, other: object) -> bool: class Response(StreamResponse): - __slots__ = ( "_body_payload", "_compressed_body", @@ -578,12 +568,7 @@ def body(self) -> Optional[Union[bytes, Payload]]: return self._body @body.setter - def body( - self, - body: bytes, - CONTENT_TYPE: istr = hdrs.CONTENT_TYPE, - CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, - ) -> None: + def body(self, body: bytes) -> None: if body is None: self._body: Optional[bytes] = None self._body_payload: bool = False @@ -600,19 +585,13 @@ def body( headers = self._headers - # set content-length header if needed - if not self._chunked and CONTENT_LENGTH not in headers: - size = body.size - if size is not None: - headers[CONTENT_LENGTH] = str(size) - # set content-type - if CONTENT_TYPE not in headers: - headers[CONTENT_TYPE] = body.content_type + if hdrs.CONTENT_TYPE not in headers: + headers[hdrs.CONTENT_TYPE] = body.content_type # copy payload headers if body.headers: - for (key, value) in body.headers.items(): + for key, value in body.headers.items(): if key not in headers: headers[key] = value @@ -686,21 +665,16 @@ async def write_eof(self, data: bytes = b"") -> None: async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: if not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: - if not self._body_payload: - if self._body is not None: - self._headers[hdrs.CONTENT_LENGTH] = str(len(self._body)) - else: - self._headers[hdrs.CONTENT_LENGTH] = "0" + if self._body_payload: + size = cast(Payload, self._body).size + if size is not None: + self._headers[hdrs.CONTENT_LENGTH] = str(size) + else: + body_len = len(self._body) if self._body else "0" + self._headers[hdrs.CONTENT_LENGTH] = str(body_len) return await super()._start(request) - def _compress_body(self, zlib_mode: int) -> None: - assert zlib_mode > 0 - compressobj = zlib.compressobj(wbits=zlib_mode) - body_in = self._body - assert body_in is not None - self._compressed_body = compressobj.compress(body_in) + compressobj.flush() - async def _do_start_compression(self, coding: ContentCoding) -> None: if self._body_payload or self._chunked: return await super()._do_start_compression(coding) @@ -708,26 +682,26 @@ async def _do_start_compression(self, coding: ContentCoding) -> None: if coding != ContentCoding.identity: # Instead of using _payload_writer.enable_compression, # compress the whole body - zlib_mode = ( - 16 + zlib.MAX_WBITS if coding == ContentCoding.gzip else zlib.MAX_WBITS + compressor = ZLibCompressor( + encoding=str(coding.value), + max_sync_chunk_size=self._zlib_executor_size, + executor=self._zlib_executor, ) - body_in = self._body - assert body_in is not None - if ( - self._zlib_executor_size is not None - and len(body_in) > self._zlib_executor_size - ): - await asyncio.get_event_loop().run_in_executor( - self._zlib_executor, self._compress_body, zlib_mode + assert self._body is not None + if self._zlib_executor_size is None and len(self._body) > 1024 * 1024: + warnings.warn( + "Synchronous compression of large response bodies " + f"({len(self._body)} bytes) might block the async event loop. " + "Consider providing a custom value to zlib_executor_size/" + "zlib_executor response properties or disabling compression on it." ) - else: - self._compress_body(zlib_mode) - - body_out = self._compressed_body - assert body_out is not None + self._compressed_body = ( + await compressor.compress(self._body) + compressor.flush() + ) + assert self._compressed_body is not None self._headers[hdrs.CONTENT_ENCODING] = coding.value - self._headers[hdrs.CONTENT_LENGTH] = str(len(body_out)) + self._headers[hdrs.CONTENT_LENGTH] = str(len(self._compressed_body)) def json_response( diff --git a/aiohttp/web_runner.py b/aiohttp/web_runner.py index 4b1408c31a6..3063dce36ec 100644 --- a/aiohttp/web_runner.py +++ b/aiohttp/web_runner.py @@ -2,6 +2,7 @@ import signal import socket from abc import ABC, abstractmethod +from contextlib import suppress from typing import Any, List, Optional, Set, Type from yarl import URL @@ -80,11 +81,23 @@ async def stop(self) -> None: # named pipes do not have wait_closed property if hasattr(self._server, "wait_closed"): await self._server.wait_closed() + + # Wait for pending tasks for a given time limit. + with suppress(asyncio.TimeoutError): + await asyncio.wait_for( + self._wait(asyncio.current_task()), timeout=self._shutdown_timeout + ) + await self._runner.shutdown() assert self._runner.server await self._runner.server.shutdown(self._shutdown_timeout) self._runner._unreg_site(self) + async def _wait(self, parent_task: Optional["asyncio.Task[object]"]) -> None: + exclude = self._runner.starting_tasks | {asyncio.current_task(), parent_task} + while tasks := asyncio.all_tasks() - exclude: + await asyncio.wait(tasks) + class TCPSite(BaseSite): __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port") @@ -247,7 +260,7 @@ async def start(self) -> None: class BaseRunner(ABC): - __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites") + __slots__ = ("starting_tasks", "_handle_signals", "_kwargs", "_server", "_sites") def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None: self._handle_signals = handle_signals @@ -287,6 +300,11 @@ async def setup(self) -> None: pass self._server = await self._make_server() + # On shutdown we want to avoid waiting on tasks which run forever. + # It's very likely that all tasks which run forever will have been created by + # the time we have completed the application startup (in self._make_server()), + # so we just record all running tasks here and exclude them later. + self.starting_tasks = asyncio.all_tasks() @abstractmethod async def shutdown(self) -> None: @@ -368,7 +386,6 @@ def __init__( access_log_class: Type[AbstractAccessLogger] = AccessLogger, **kwargs: Any, ) -> None: - if not isinstance(app, Application): raise TypeError( "The first argument should be web.Application " diff --git a/aiohttp/web_server.py b/aiohttp/web_server.py index 40211463f37..a3d658afbff 100644 --- a/aiohttp/web_server.py +++ b/aiohttp/web_server.py @@ -19,6 +19,7 @@ def __init__( *, request_factory: Optional[_RequestFactory] = None, debug: Optional[bool] = None, + handler_cancellation: bool = False, **kwargs: Any, ) -> None: if debug is not None: @@ -33,6 +34,7 @@ def __init__( self.requests_count = 0 self.request_handler = handler self.request_factory = request_factory or self._make_request + self.handler_cancellation = handler_cancellation @property def connections(self) -> List[RequestHandler]: diff --git a/aiohttp/web_urldispatcher.py b/aiohttp/web_urldispatcher.py index c9b8c6e4ad4..81e152b8b9d 100644 --- a/aiohttp/web_urldispatcher.py +++ b/aiohttp/web_urldispatcher.py @@ -1,4 +1,5 @@ import abc +import asyncio import base64 import hashlib import keyword @@ -35,7 +36,7 @@ from . import hdrs from .abc import AbstractMatchInfo, AbstractRouter, AbstractView -from .helpers import DEBUG, iscoroutinefunction +from .helpers import DEBUG from .http import HttpVersion11 from .typedefs import Handler, PathLike from .web_exceptions import ( @@ -161,11 +162,10 @@ def __init__( expect_handler: Optional[_ExpectHandler] = None, resource: Optional[AbstractResource] = None, ) -> None: - if expect_handler is None: expect_handler = _default_expect_handler - assert iscoroutinefunction( + assert asyncio.iscoroutinefunction( expect_handler ), f"Coroutine is expected, got {expect_handler!r}" @@ -173,7 +173,7 @@ def __init__( if not HTTP_METHOD_RE.match(method): raise ValueError(f"{method} is not allowed HTTP method") - if iscoroutinefunction(handler): + if asyncio.iscoroutinefunction(handler): pass elif isinstance(handler, type) and issubclass(handler, AbstractView): pass @@ -325,7 +325,6 @@ def add_route( *, expect_handler: Optional[_ExpectHandler] = None, ) -> "ResourceRoute": - for route_obj in self._routes: if route_obj.method == method or route_obj.method == hdrs.METH_ANY: raise RuntimeError( @@ -415,7 +414,6 @@ def __repr__(self) -> str: class DynamicResource(Resource): - DYN = re.compile(r"\{(?P[_a-zA-Z][_a-zA-Z0-9]*)\}") DYN_WITH_RE = re.compile(r"\{(?P[_a-zA-Z][_a-zA-Z0-9]*):(?P.+)\}") GOOD = r"[^{}/]+" @@ -974,7 +972,6 @@ def __contains__(self, route: object) -> bool: class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]): - NAME_SPLIT_RE = re.compile(r"[.:-]") def __init__(self) -> None: diff --git a/aiohttp/worker.py b/aiohttp/worker.py index c73178d314c..dcf147e5ac4 100644 --- a/aiohttp/worker.py +++ b/aiohttp/worker.py @@ -26,11 +26,10 @@ SSLContext = object # type: ignore[misc,assignment] -__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker", "GunicornTokioWebWorker") +__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker") class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported] - DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default @@ -179,14 +178,8 @@ def init_signals(self) -> None: signal.siginterrupt(signal.SIGUSR1, False) # Reset signals so Gunicorn doesn't swallow subprocess return codes # See: https://github.com/aio-libs/aiohttp/issues/6130 - if sys.version_info < (3, 8): - # Starting from Python 3.8, - # the default child watcher is ThreadedChildWatcher. - # The watcher doesn't depend on SIGCHLD signal, - # there is no need to reset it. - signal.signal(signal.SIGCHLD, signal.SIG_DFL) - - def handle_quit(self, sig: int, frame: FrameType) -> None: + + def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None: self.alive = False # worker_int callback @@ -195,7 +188,7 @@ def handle_quit(self, sig: int, frame: FrameType) -> None: # wakeup closing process self._notify_waiter_done() - def handle_abort(self, sig: int, frame: FrameType) -> None: + def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None: self.alive = False self.exit_code = 1 self.cfg.worker_abort(self) @@ -244,15 +237,3 @@ def init_process(self) -> None: asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) super().init_process() - - -class GunicornTokioWebWorker(GunicornWebWorker): - def init_process(self) -> None: # pragma: no cover - import tokio - - # Setup tokio policy, so that every - # asyncio.get_event_loop() will create an instance - # of tokio event loop. - asyncio.set_event_loop_policy(tokio.EventLoopPolicy()) - - super().init_process() diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 5cb71a959d1..6200c79b7ad 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -480,26 +480,12 @@ checks can be relaxed by setting *ssl* to ``False``:: r = await session.get('https://example.com', ssl=False) - If you need to setup custom ssl parameters (use own certification files for example) you can create a :class:`ssl.SSLContext` instance and -pass it into the proper :class:`ClientSession` method:: - - sslcontext = ssl.create_default_context( - cafile='/path/to/ca-bundle.crt') - r = await session.get('https://example.com', ssl=sslcontext) - -If you need to verify *self-signed* certificates, you can do the -same thing as the previous example, but add another call to -:meth:`ssl.SSLContext.load_cert_chain` with the key pair:: - - sslcontext = ssl.create_default_context( - cafile='/path/to/ca-bundle.crt') - sslcontext.load_cert_chain('/path/to/client/public/device.pem', - '/path/to/client/private/device.key') - r = await session.get('https://example.com', ssl=sslcontext) +pass it into the :meth:`ClientSession.request` methods or set it for the +entire session with ``ClientSession(connector=TCPConnector(ssl=ssl_context))``. -There is explicit errors when ssl verification fails +There are explicit errors when ssl verification fails :class:`aiohttp.ClientConnectorSSLError`:: @@ -529,6 +515,34 @@ If you need to skip both ssl related errors except aiohttp.ClientSSLError as e: assert isinstance(e, ssl.CertificateError) +Example: Use certifi +^^^^^^^^^^^^^^^^^^^^ + +By default, Python uses the system CA certificates. In rare cases, these may not be +installed or Python is unable to find them, resulting in a error like +`ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate` + +One way to work around this problem is to use the `certifi` package:: + + ssl_context = ssl.create_default_context(cafile=certifi.where()) + async with ClientSession(connector=TCPConnector(ssl=ssl_context)) as sess: + ... + +Example: Use self-signed certificate +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +If you need to verify *self-signed* certificates, you need to add a call to +:meth:`ssl.SSLContext.load_cert_chain` with the key pair:: + + ssl_context = ssl.create_default_context() + ssl_context.load_cert_chain("/path/to/client/public/device.pem", + "/path/to/client/private/device.key") + async with sess.get("https://example.com", ssl=ssl_context) as resp: + ... + +Example: Verify certificate fingerprint +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + You may also verify certificates via *SHA256* fingerprint:: # Attempt to connect to https://www.python.org @@ -561,6 +575,8 @@ DER with e.g:: to :class:`TCPConnector` as default, the value from :meth:`ClientSession.get` and others override default. +.. _aiohttp-client-proxy-support: + Proxy support ------------- @@ -592,21 +608,24 @@ Authentication credentials can be passed in proxy URL:: Contrary to the ``requests`` library, it won't read environment variables by default. But you can do so by passing ``trust_env=True`` into :class:`aiohttp.ClientSession` -constructor for extracting proxy configuration from -*HTTP_PROXY*, *HTTPS_PROXY*, *WS_PROXY* or *WSS_PROXY* *environment -variables* (all are case insensitive):: +constructor.:: async with aiohttp.ClientSession(trust_env=True) as session: async with session.get("http://python.org") as resp: print(resp.status) +.. note:: + aiohttp uses :func:`urllib.request.getproxies` + for reading the proxy configuration (e.g. from the *HTTP_PROXY* etc. environment variables) and applies them for the *HTTP*, *HTTPS*, *WS* and *WSS* schemes. + + Hosts defined in ``no_proxy`` will bypass the proxy. + Proxy credentials are given from ``~/.netrc`` file if present (see :class:`aiohttp.ClientSession` for more details). .. attention:: - CPython has introduced the support for TLS in TLS around Python 3.7. - But, as of now (Python 3.10), it's disabled for the transports that + As of now (Python 3.10), support for TLS in TLS is disabled for the transports that :py:mod:`asyncio` uses. If the further release of Python (say v3.11) toggles one attribute, it'll *just work™*. diff --git a/docs/client_quickstart.rst b/docs/client_quickstart.rst index c7fec9f1936..92334a5f4b4 100644 --- a/docs/client_quickstart.rst +++ b/docs/client_quickstart.rst @@ -68,7 +68,7 @@ endpoints of ``http://httpbin.org`` can be used the following code:: .. note:: Don't create a session per request. Most likely you need a session - per application which performs all requests altogether. + per application which performs all requests together. More complex cases may require a session per site, e.g. one for Github and other one for Facebook APIs. Anyway making a session for @@ -316,7 +316,7 @@ To upload Multipart-encoded files:: You can set the ``filename`` and ``content_type`` explicitly:: url = 'http://httpbin.org/post' - data = FormData() + data = aiohttp.FormData() data.add_field('file', open('report.xls', 'rb'), filename='report.xls', diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 1ef29118e9a..407030dce14 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -32,8 +32,7 @@ Usage example:: html = await fetch(client) print(html) - loop = asyncio.get_event_loop() - loop.run_until_complete(main()) + asyncio.run(main()) The client session supports the context manager protocol for self closing. @@ -173,12 +172,17 @@ The client session supports the context manager protocol for self closing. .. versionadded:: 3.7 - :param bool trust_env: Get proxies information from *HTTP_PROXY* / - *HTTPS_PROXY* environment variables if the parameter is ``True`` - (``False`` by default). + :param bool trust_env: Trust environment settings for proxy configuration if the parameter + is ``True`` (``False`` by default). See :ref:`aiohttp-client-proxy-support` for + more information. Get proxy credentials from ``~/.netrc`` file if present. + Get HTTP Basic Auth credentials from :file:`~/.netrc` file if present. + + If :envvar:`NETRC` environment variable is set, read from file specified + there rather than from :file:`~/.netrc`. + .. seealso:: ``.netrc`` documentation: https://www.gnu.org/software/inetutils/manual/html_node/The-_002enetrc-file.html @@ -189,6 +193,10 @@ The client session supports the context manager protocol for self closing. Added support for ``~/.netrc`` file. + .. versionchanged:: 3.9 + + Added support for reading HTTP Basic Auth credentials from :file:`~/.netrc` file. + :param bool requote_redirect_url: Apply *URL requoting* for redirection URLs if automatic redirection is enabled (``True`` by default). @@ -312,8 +320,9 @@ The client session supports the context manager protocol for self closing. .. attribute:: trust_env - Should get proxies information from HTTP_PROXY / HTTPS_PROXY environment - variables or ~/.netrc file if present + Trust environment settings for proxy configuration + or ~/.netrc file if present. See :ref:`aiohttp-client-proxy-support` for + more information. :class:`bool` default is ``False`` @@ -347,7 +356,9 @@ The client session supports the context manager protocol for self closing. :param str method: HTTP method - :param url: Request URL, :class:`str` or :class:`~yarl.URL`. + :param url: Request URL, :class:`~yarl.URL` or :class:`str` that will + be encoded with :class:`~yarl.URL` (see :class:`~yarl.URL` + to skip encoding). :param params: Mapping, iterable of tuple of *key*/*value* pairs or string to be sent as parameters in the query @@ -666,7 +677,9 @@ The client session supports the context manager protocol for self closing. Create a websocket connection. Returns a :class:`ClientWebSocketResponse` object. - :param url: Websocket server url, :class:`str` or :class:`~yarl.URL` + :param url: Websocket server url, :class:`~yarl.URL` or :class:`str` that + will be encoded with :class:`~yarl.URL` (see :class:`~yarl.URL` + to skip encoding). :param tuple protocols: Websocket protocols @@ -831,7 +844,9 @@ certification chaining. :param str method: HTTP method - :param url: Requested URL, :class:`str` or :class:`~yarl.URL` + :param url: Request URL, :class:`~yarl.URL` or :class:`str` that will + be encoded with :class:`~yarl.URL` (see :class:`~yarl.URL` + to skip encoding). :param dict params: Parameters to be sent in the query string of the new request (optional) @@ -2085,7 +2100,7 @@ All exceptions are available as members of *aiohttp* module. Represent Content-Disposition header - .. attribute:: value + .. attribute:: type A :class:`str` instance. Value of Content-Disposition header itself, e.g. ``attachment``. diff --git a/docs/conf.py b/docs/conf.py index 7585dfa5206..f1b3b44c41b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -75,6 +75,7 @@ "aiohttpremotes": ("https://aiohttp-remotes.readthedocs.io/en/stable/", None), "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None), "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), + "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None), "asynctest": ("https://asynctest.readthedocs.io/en/latest/", None), } @@ -200,12 +201,6 @@ "height": "20", "alt": "Latest PyPI package version", }, - { - "image": f"https://img.shields.io/discourse/status?server=https%3A%2F%2F{github_repo_org}.discourse.group", - "target": f"https://{github_repo_org}.discourse.group", - "height": "20", - "alt": "Discourse status", - }, { "image": "https://badges.gitter.im/Join%20Chat.svg", "target": f"https://gitter.im/{github_repo_org}/Lobby", diff --git a/docs/contributing.rst b/docs/contributing.rst index 8476b7904f3..f0c65a9b815 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -110,6 +110,8 @@ Install pre-commit hooks: Congratulations, you are ready to run the test suite! +.. include:: ../vendor/README.rst + Run autoformatter ----------------- @@ -315,7 +317,7 @@ The rules for committers are simple: 4. Keep test suite comprehensive. In practice it means leveling up coverage. 97% is not bad but we wish to have 100% someday. Well, 99% is good target too. -5. Don't hesitate to improve our docs. Documentation is very important +5. Don't hesitate to improve our docs. Documentation is a very important thing, it's the key for project success. The documentation should not only cover our public API but help newbies to start using the project and shed a light on non-obvious gotchas. diff --git a/docs/glossary.rst b/docs/glossary.rst index 497f901176b..81bfcfa654b 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -89,6 +89,8 @@ It makes communication faster by getting rid of connection establishment for every request. + + nginx Nginx [engine x] is an HTTP and reverse proxy server, a mail @@ -153,3 +155,16 @@ A library for operating with URL objects. https://pypi.python.org/pypi/yarl + + +Environment Variables +===================== + +.. envvar:: NETRC + + If set, HTTP Basic Auth will be read from the file pointed to by this environment variable, + rather than from :file:`~/.netrc`. + + .. seealso:: + + ``.netrc`` documentation: https://www.gnu.org/software/inetutils/manual/html_node/The-_002enetrc-file.html diff --git a/docs/http_request_lifecycle.rst b/docs/http_request_lifecycle.rst index e14fb03de5f..22f6fbb8e30 100644 --- a/docs/http_request_lifecycle.rst +++ b/docs/http_request_lifecycle.rst @@ -77,8 +77,7 @@ So you are expected to reuse a session object and make many requests from it. Fo html = await response.text() print(html) - loop = asyncio.get_event_loop() - loop.run_until_complete(main()) + asyncio.run(main()) Can become this: @@ -98,8 +97,7 @@ Can become this: html = await fetch(session, 'http://python.org') print(html) - loop = asyncio.get_event_loop() - loop.run_until_complete(main()) + asyncio.run(main()) On more complex code bases, you can even create a central registry to hold the session object from anywhere in the code, or a higher level ``Client`` class that holds a reference to it. diff --git a/docs/index.rst b/docs/index.rst index b811fa5210b..a8b99234ae7 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -50,8 +50,8 @@ This option is highly recommended: $ pip install aiodns -Installing speedups altogether ------------------------------- +Installing all speedups in one command +-------------------------------------- The following will get you ``aiohttp`` along with :term:`cchardet`, :term:`aiodns` and ``Brotli`` in one bundle. No need to type @@ -83,8 +83,7 @@ Client example html = await response.text() print("Body:", html[:15], "...") - loop = asyncio.get_event_loop() - loop.run_until_complete(main()) + asyncio.run(main()) This prints: @@ -141,14 +140,10 @@ Please feel free to file an issue on the `bug tracker `_ if you have found a bug or have some suggestion in order to improve the library. -The library uses `Azure Pipelines `_ for -Continuous Integration. - Dependencies ============ -- Python 3.7+ - *async_timeout* - *charset-normalizer* - *multidict* @@ -179,7 +174,7 @@ Dependencies Communication channels ====================== -*aio-libs discourse group*: https://aio-libs.discourse.group +*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions Feel free to post your questions and ideas here. diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 4f96f987f3f..4c732ddd1d4 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -39,7 +39,6 @@ BodyPartReader boolean botocore brotli -brotli Brotli brotlipy bugfix @@ -102,6 +101,7 @@ dns DNSResolver docstring docstrings +DoS Dup elasticsearch encodings @@ -340,6 +340,7 @@ utils uvloop uWSGI vcvarsall +vendored waituntil wakeup wakeups diff --git a/docs/testing.rst b/docs/testing.rst index 961f78f523c..c94ad6708e6 100644 --- a/docs/testing.rst +++ b/docs/testing.rst @@ -318,9 +318,12 @@ functionality, the AioHTTPTestCase is provided:: A base class to allow for unittest web applications using aiohttp. - Derived from :class:`unittest.TestCase` + Derived from :class:`unittest.IsolatedAsyncioTestCase` - Provides the following: + See :class:`unittest.TestCase` and :class:`unittest.IsolatedAsyncioTestCase` + for inherited methods and behavior. + + This class additionally provides the following: .. attribute:: client @@ -332,12 +335,6 @@ functionality, the AioHTTPTestCase is provided:: .. versionadded:: 2.3 - .. attribute:: loop - - The event loop in which the application and server are running. - - .. deprecated:: 3.5 - .. attribute:: app The application returned by :meth:`~aiohttp.test_utils.AioHTTPTestCase.get_application` @@ -369,123 +366,38 @@ functionality, the AioHTTPTestCase is provided:: :return: :class:`aiohttp.web.Application` instance. - .. comethod:: setUpAsync() + .. comethod:: asyncSetUp() This async method can be overridden to execute asynchronous code during the ``setUp`` stage of the ``TestCase``:: - async def setUpAsync(self): - await super().setUpAsync() + async def asyncSetUp(self): + await super().asyncSetUp() await foo() .. versionadded:: 2.3 .. versionchanged:: 3.8 - ``await super().setUpAsync()`` call is required. + ``await super().asyncSetUp()`` call is required. - .. comethod:: tearDownAsync() + .. comethod:: asyncTearDown() This async method can be overridden to execute asynchronous code during the ``tearDown`` stage of the ``TestCase``:: - async def tearDownAsync(self): - await super().tearDownAsync() + async def asyncTearDown(self): + await super().asyncTearDown() await foo() .. versionadded:: 2.3 .. versionchanged:: 3.8 - ``await super().tearDownAsync()`` call is required. - - .. method:: setUp() - - Standard test initialization method. - - .. method:: tearDown() - - Standard test finalization method. - - - .. note:: - - The ``TestClient``'s methods are asynchronous: you have to - execute functions on the test client using asynchronous methods.:: - - class TestA(AioHTTPTestCase): - - async def test_f(self): - async with self.client.get('/') as resp: - body = await resp.text() - -Patching unittest test cases -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Patching test cases is tricky, when using python older than 3.8 :py:func:`~unittest.mock.patch` does not behave as it has to. -We recommend using :py:mod:`asynctest` that provides :py:func:`~asynctest.patch` that is capable of creating -a magic mock that supports async. It can be used with a decorator as well as with a context manager: - -.. code-block:: python - :emphasize-lines: 1,37,46 - - from asynctest.mock import patch as async_patch - - from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop - from aiohttp.web_app import Application - from aiohttp.web_request import Request - from aiohttp.web_response import Response - from aiohttp.web_routedef import get - - - async def do_something(): - print('something') - - - async def ping(request: Request) -> Response: - await do_something() - return Response(text='pong') - - - class TestApplication(AioHTTPTestCase): - def get_app(self) -> Application: - app = Application() - app.router.add_routes([ - get('/ping/', ping) - ]) - - return app - - @unittest_run_loop - async def test_ping(self): - resp = await self.client.get('/ping/') - - self.assertEqual(resp.status, 200) - self.assertEqual(await resp.text(), 'pong') - - @unittest_run_loop - async def test_ping_mocked_do_something(self): - with async_patch('tests.do_something') as do_something_patch: - resp = await self.client.get('/ping/') - - self.assertEqual(resp.status, 200) - self.assertEqual(await resp.text(), 'pong') - - self.assertTrue(do_something_patch.called) - - @unittest_run_loop - @async_patch('tests.do_something') - async def test_ping_mocked_do_something_decorated(self, do_something_patch): - resp = await self.client.get('/ping/') - - self.assertEqual(resp.status, 200) - self.assertEqual(await resp.text(), 'pong') - - self.assertTrue(do_something_patch.called) - + ``await super().asyncTearDown()`` call is required. Faking request object ---------------------- +^^^^^^^^^^^^^^^^^^^^^ aiohttp provides test utility for creating fake :class:`aiohttp.web.Request` objects: @@ -543,7 +455,7 @@ conditions that hard to reproduce on real server:: :param headers: mapping containing the headers. Can be anything accepted by the multidict.CIMultiDict constructor. - :type headers: dict, multidict.CIMultiDict, list of pairs + :type headers: dict, multidict.CIMultiDict, list of tuple(str, str) :param match_info: mapping containing the info to match with url parameters. :type match_info: dict @@ -584,7 +496,7 @@ conditions that hard to reproduce on real server:: Framework Agnostic Utilities -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +---------------------------- High level test creation:: diff --git a/docs/third_party.rst b/docs/third_party.rst index f73d819bbaa..9dca72369dd 100644 --- a/docs/third_party.rst +++ b/docs/third_party.rst @@ -112,8 +112,7 @@ support to aiohttp web servers. - `aiohttp-pydantic `_ An ``aiohttp.View`` to validate the HTTP request's body, query-string, and - headers regarding function annotations and generate OpenAPI doc. Python 3.8+ - required. + headers regarding function annotations and generate OpenAPI doc. - `aiohttp-swagger `_ Swagger API Documentation builder for aiohttp server. @@ -258,6 +257,8 @@ ask to raise the status. - `GINO `_ An asyncio ORM on top of SQLAlchemy core, delivered with an aiohttp extension. +- `New Relic `_ An aiohttp middleware for reporting your `Python application performance `_ metrics to New Relic. + - `eider-py `_ Python implementation of the `Eider RPC protocol `_. diff --git a/docs/web_advanced.rst b/docs/web_advanced.rst index 220a4d8e7df..6055ddaf319 100644 --- a/docs/web_advanced.rst +++ b/docs/web_advanced.rst @@ -19,14 +19,25 @@ But in case of custom regular expressions for *percent encoded*: if you pass Unicode patterns they don't match to *requoted* path. +.. _aiohttp-web-peer-disconnection: + Peer disconnection ------------------ -When a client peer is gone a subsequent reading or writing raises :exc:`OSError` -or more specific exception like :exc:`ConnectionResetError`. +*aiohttp* has 2 approaches to handling client disconnections. +If you are familiar with asyncio, or scalability is a concern for +your application, we recommend using the handler cancellation method. + +Raise on read/write (default) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When a client peer is gone, a subsequent reading or writing raises :exc:`OSError` +or a more specific exception like :exc:`ConnectionResetError`. -The reason for disconnection is vary; it can be a network issue or explicit -socket closing on the peer side without reading the whole server response. +This behavior is similar to classic WSGI frameworks like Flask and Django. + +The reason for disconnection varies; it can be a network issue or explicit +socket closing on the peer side without reading the full server response. *aiohttp* handles disconnection properly but you can handle it explicitly, e.g.:: @@ -36,6 +47,122 @@ socket closing on the peer side without reading the whole server response. except OSError: # disconnected +Web handler cancellation +^^^^^^^^^^^^^^^^^^^^^^^^ + +This method can be enabled using the ``handler_cancellation`` parameter +to :func:`run_app`. + +When a client disconnects, the web handler task will be cancelled. This +is recommended as it can reduce the load on your server when there is no +client to receive a response. It can also help make your application +more resilient to DoS attacks (by requiring an attacker to keep a +connection open in order to waste server resources). + +This behavior is very different from classic WSGI frameworks like +Flask and Django. It requires a reasonable level of asyncio knowledge to +use correctly without causing issues in your code. We provide some +examples here to help understand the complexity and methods +needed to deal with them. + +.. warning:: + + :term:`web-handler` execution could be canceled on every ``await`` + if client drops connection without reading entire response's BODY. + +Sometimes it is a desirable behavior: on processing ``GET`` request the +code might fetch data from a database or other web resource, the +fetching is potentially slow. + +Canceling this fetch is a good idea: the peer dropped connection +already, so there is no reason to waste time and resources (memory etc) +by getting data from a DB without any chance to send it back to peer. + +But sometimes the cancellation is bad: on ``POST`` request very often +it is needed to save data to a DB regardless of peer closing. + +Cancellation prevention could be implemented in several ways: + +* Applying :func:`asyncio.shield` to a coroutine that saves data. +* Using aiojobs_ or another third party library. + +:func:`asyncio.shield` can work well. The only disadvantage is you +need to split web handler into exactly two async functions: one +for handler itself and other for protected code. + +For example the following snippet is not safe:: + + async def handler(request): + await asyncio.shield(write_to_redis(request)) + await asyncio.shield(write_to_postgres(request)) + return web.Response(text="OK") + +Cancellation might occur while saving data in REDIS, so +``write_to_postgres`` will not be called, potentially +leaving your data in an inconsistent state. + +Instead, you would need to write something like:: + + async def write_data(request): + await write_to_redis(request) + await write_to_postgres(request) + + async def handler(request): + await asyncio.shield(write_data(request)) + return web.Response(text="OK") + +Alternatively, if you want to spawn a task without waiting for +its completion, you can use aiojobs_ which provides an API for +spawning new background jobs. It stores all scheduled activity in +internal data structures and can terminate them gracefully:: + + from aiojobs.aiohttp import setup, spawn + + async def handler(request): + await spawn(request, write_data()) + return web.Response() + + app = web.Application() + setup(app) + app.router.add_get("/", handler) + +.. warning:: + + Don't use :func:`asyncio.create_task` for this. All tasks + should be awaited at some point in your code (``aiojobs`` handles + this for you), otherwise you will hide legitimate exceptions + and result in warnings being emitted. + + A good case for using :func:`asyncio.create_task` is when + you want to run something while you are processing other data, + but still want to ensure the task is complete before returning:: + + async def handler(request): + t = asyncio.create_task(get_some_data()) + ... # Do some other things, while data is being fetched. + data = await t + return web.Response(text=data) + +One more approach would be to use :func:`aiojobs.aiohttp.atomic` +decorator to execute the entire handler as a new job. Essentially +restoring the default disconnection behavior only for specific handlers:: + + from aiojobs.aiohttp import atomic + + @atomic + async def handler(request): + await write_to_db() + return web.Response() + + app = web.Application() + setup(app) + app.router.add_post("/", handler) + +It prevents all of the ``handler`` async function from cancellation, +so ``write_to_db`` will be never interrupted. + +.. _aiojobs: http://aiojobs.readthedocs.io/en/latest/ + Passing a coroutine into run_app and Gunicorn --------------------------------------------- @@ -800,8 +927,14 @@ Graceful shutdown Stopping *aiohttp web server* by just closing all connections is not always satisfactory. -The problem is: if application supports :term:`websocket`\s or *data -streaming* it most likely has open connections at server +The first thing aiohttp will do is to stop listening on the sockets, +so new connections will be rejected. It will then wait a few +seconds to allow any pending tasks to complete before continuing +with application shutdown. The timeout can be adjusted with +``shutdown_timeout`` in :func:`run_app`. + +Another problem is if the application supports :term:`websockets ` or +*data streaming* it most likely has open connections at server shutdown time. The *library* has no knowledge how to close them gracefully but diff --git a/docs/web_lowlevel.rst b/docs/web_lowlevel.rst index 456b8fea4cd..6d76ca3b95c 100644 --- a/docs/web_lowlevel.rst +++ b/docs/web_lowlevel.rst @@ -69,13 +69,7 @@ The following code demonstrates very trivial usage example:: await asyncio.sleep(100*3600) - loop = asyncio.get_event_loop() - - try: - loop.run_until_complete(main()) - except KeyboardInterrupt: - pass - loop.close() + asyncio.run(main()) In the snippet we have ``handler`` which returns a regular diff --git a/docs/web_quickstart.rst b/docs/web_quickstart.rst index a9e080776df..7bfb311e670 100644 --- a/docs/web_quickstart.rst +++ b/docs/web_quickstart.rst @@ -148,6 +148,12 @@ for a ``GET`` request. You can also deny ``HEAD`` requests on a route:: Here ``handler`` won't be called on ``HEAD`` request and the server will respond with ``405: Method Not Allowed``. +.. seealso:: + + :ref:`aiohttp-web-peer-disconnection` section explains how handlers + behave when a client connection drops and ways to optimize handling + of this. + .. _aiohttp-web-resource-and-route: Resources and Routes diff --git a/docs/web_reference.rst b/docs/web_reference.rst index f79dd050c3b..baf36a80682 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -900,30 +900,15 @@ Response Read-write attribute for storing response's content aka BODY, :class:`bytes`. - Setting :attr:`body` also recalculates - :attr:`~StreamResponse.content_length` value. - Assigning :class:`str` to :attr:`body` will make the :attr:`body` type of :class:`aiohttp.payload.StringPayload`, which tries to encode the given data based on *Content-Type* HTTP header, while defaulting to ``UTF-8``. - Resetting :attr:`body` (assigning ``None``) sets - :attr:`~StreamResponse.content_length` to ``None`` too, dropping - *Content-Length* HTTP header. - .. attribute:: text - Read-write attribute for storing response's content, represented as - string, :class:`str`. - - Setting :attr:`text` also recalculates - :attr:`~StreamResponse.content_length` value and - :attr:`~aiohttp.StreamResponse.body` value - - Resetting :attr:`text` (assigning ``None``) sets - :attr:`~StreamResponse.content_length` to ``None`` too, dropping - *Content-Length* HTTP header. + Read-write attribute for storing response's + :attr:`~aiohttp.StreamResponse.body`, represented as :class:`str`. FileResponse @@ -1392,6 +1377,13 @@ duplicated like one using :meth:`~aiohttp.web.Application.copy`. async def on_prepare(request, response): pass + .. note:: + + The headers are written immediately after these callbacks are run. + Therefore, if you modify the content of the response, you may need to + adjust the `Content-Length` header or similar to match. Aiohttp will + not make any updates to the headers from this point. + .. attribute:: on_startup A :class:`~aiosignal.Signal` that is fired on application start-up. @@ -2688,9 +2680,10 @@ application on specific TCP or Unix socket, e.g.:: :param int port: PORT to listed on, ``8080`` if ``None`` (default). - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP @@ -2723,9 +2716,10 @@ application on specific TCP or Unix socket, e.g.:: :param str path: PATH to UNIX socket to listen. - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP @@ -2745,9 +2739,10 @@ application on specific TCP or Unix socket, e.g.:: :param str path: PATH of named pipe to listen. - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. .. class:: SockSite(runner, sock, *, \ shutdown_timeout=60.0, ssl_context=None, \ @@ -2759,9 +2754,10 @@ application on specific TCP or Unix socket, e.g.:: :param sock: A :ref:`socket instance ` to listen to. - :param float shutdown_timeout: a timeout for closing opened - connections on :meth:`BaseSite.stop` - call. + :param float shutdown_timeout: a timeout used for both waiting on pending + tasks before application shutdown and for + closing opened connections on + :meth:`BaseSite.stop` call. :param ssl_context: a :class:`ssl.SSLContext` instance for serving SSL/TLS secure server, ``None`` for plain HTTP @@ -2809,7 +2805,8 @@ Utilities access_log=aiohttp.log.access_logger, \ handle_signals=True, \ reuse_address=None, \ - reuse_port=None) + reuse_port=None, \ + handler_cancellation=False) A high-level function for running an application, serving it until keyboard interrupt and performing a @@ -2856,9 +2853,13 @@ Utilities shutdown before disconnecting all open client sockets hard way. + This is used as a delay to wait for + pending tasks to complete and then + again to close any pending connections. + A system with properly :ref:`aiohttp-web-graceful-shutdown` - implemented never waits for this + implemented never waits for the second timeout but closes a server in a few milliseconds. @@ -2905,6 +2906,12 @@ Utilities this flag when being created. This option is not supported on Windows. + :param bool handler_cancellation: cancels the web handler task if the client + drops the connection. This is recommended + if familiar with asyncio behavior or + scalability is a concern. + :ref:`aiohttp-web-peer-disconnection` + .. versionadded:: 3.0 Support *access_log_class* parameter. @@ -2915,6 +2922,11 @@ Utilities Accept a coroutine as *app* parameter. + .. versionadded:: 3.9 + + Support handler_cancellation parameter (this was the default behavior + in aiohttp <3.7). + Constants --------- diff --git a/requirements/base.txt b/requirements/base.txt index f5013b23a3b..e2ac167556e 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -4,11 +4,10 @@ aiodns==3.0.0; sys_platform=="linux" or sys_platform=="darwin" aiosignal==1.2.0 async-timeout==4.0.2 -asynctest==0.13.0; python_version<"3.8" Brotli==1.0.9 cchardet==2.1.7; python_version < "3.10" # Unmaintained: aio-libs/aiohttp#6819 charset-normalizer==2.0.12 -frozenlist==1.3.1 +frozenlist==1.4.0 gunicorn==20.1.0 -uvloop==0.14.0; platform_system!="Windows" and implementation_name=="cpython" and python_version<"3.9" # MagicStack/uvloop#14 -yarl==1.8.1 +uvloop==0.17.0; platform_system!="Windows" and implementation_name=="cpython" and python_version<"3.9" # MagicStack/uvloop#14 +yarl==1.9.2 diff --git a/requirements/constraints.txt b/requirements/constraints.txt index ac00eaa740e..4eec2d51418 100644 --- a/requirements/constraints.txt +++ b/requirements/constraints.txt @@ -51,7 +51,7 @@ click==8.0.3 # wait-for-it click-default-group==1.2.2 # via towncrier -coverage==6.4.2 +coverage==6.5.0 # via # -r requirements/test.txt # pytest-cov @@ -60,7 +60,7 @@ cryptography==36.0.1 ; platform_machine != "i686" # -r requirements/test.txt # pyjwt # trustme -cython==0.29.32 +cython==3.0.0 # via -r requirements/cython.txt distlib==0.3.3 # via virtualenv @@ -70,7 +70,7 @@ filelock==3.3.2 # via virtualenv freezegun==1.1.0 # via -r requirements/test.txt -frozenlist==1.3.1 +frozenlist==1.4.0 # via # -r requirements/base.txt # aiosignal @@ -89,7 +89,7 @@ idna==3.3 # requests # trustme # yarl -imagesize==1.2.0 +imagesize==1.4.1 # via sphinx importlib-metadata==4.12.0 # via sphinx @@ -111,10 +111,6 @@ mypy==0.982 ; implementation_name == "cpython" # via # -r requirements/lint.txt # -r requirements/test.txt -mypy-extensions==0.4.3 ; implementation_name == "cpython" - # via - # -r requirements/test.txt - # mypy nodeenv==1.6.0 # via pre-commit packaging==21.2 @@ -137,7 +133,7 @@ proxy-py==2.4.4rc3 # via -r requirements/test.txt py==1.11.0 # via pytest -pycares==4.1.2 +pycares==4.3.0 # via aiodns pycparser==2.20 # via cffi @@ -145,21 +141,21 @@ pydantic==1.8.2 # via python-on-whales pyenchant==3.2.2 # via sphinxcontrib-spelling -pygments==2.11.0 +pygments==2.13.0 # via sphinx pyjwt==2.3.0 # via gidgethub pyparsing==2.4.7 # via packaging -pytest==6.2.5 +pytest==7.4.0 # via # -r requirements/lint.txt # -r requirements/test.txt # pytest-cov # pytest-mock -pytest-cov==3.0.0 +pytest-cov==4.1.0 # via -r requirements/test.txt -pytest-mock==3.6.1 +pytest-mock==3.11.1 # via -r requirements/test.txt python-dateutil==2.8.2 # via freezegun @@ -188,7 +184,7 @@ slotscheck==0.8.0 # via -r requirements/lint.txt snowballstemmer==2.1.0 # via sphinx -sphinx==5.1.1 +sphinx==5.3.0 # via # -r requirements/doc.txt # sphinxcontrib-asyncio @@ -247,7 +243,7 @@ uritemplate==4.1.1 # via gidgethub urllib3==1.26.7 # via requests -uvloop==0.14.0 ; platform_system != "Windows" and implementation_name == "cpython" and python_version < "3.9" +uvloop==0.17.0 ; platform_system != "Windows" and implementation_name == "cpython" and python_version < "3.9" # via -r requirements/base.txt virtualenv==20.10.0 # via pre-commit @@ -257,7 +253,7 @@ webcolors==1.11.1 # via blockdiag wheel==0.37.0 # via pip-tools -yarl==1.8.1 +yarl==1.9.2 # via -r requirements/base.txt zipp==3.8.1 # via importlib-metadata diff --git a/requirements/cython.txt b/requirements/cython.txt index 4a10d5fd4f7..3f996dbf400 100644 --- a/requirements/cython.txt +++ b/requirements/cython.txt @@ -1,3 +1,3 @@ -r multidict.txt -r typing-extensions.txt # required for parsing aiohttp/hdrs.py by tools/gen.py -cython==0.29.32 +cython==3.0.0 diff --git a/requirements/doc.txt b/requirements/doc.txt index 586099a15e9..46c148fbf40 100644 --- a/requirements/doc.txt +++ b/requirements/doc.txt @@ -1,7 +1,7 @@ aiohttp-theme==0.1.6 # Temp fix till updated: https://github.com/blockdiag/blockdiag/pull/148 funcparserlib==1.0.0a0 -sphinx==5.1.1 +sphinx==5.3.0 sphinxcontrib-asyncio==0.3.0 sphinxcontrib-blockdiag==3.0.0 sphinxcontrib-towncrier==0.3.0a0 diff --git a/requirements/lint.txt b/requirements/lint.txt index 1bab0a04bd7..087cb8346d3 100644 --- a/requirements/lint.txt +++ b/requirements/lint.txt @@ -2,5 +2,6 @@ aioredis==2.0.1 mypy==0.982; implementation_name=="cpython" pre-commit==2.17.0 -pytest==6.2.5 +pytest==7.4.0 slotscheck==0.8.0 +uvloop==0.17.0; platform_system!="Windows" diff --git a/requirements/test.txt b/requirements/test.txt index 1ab1c614811..1584a3dd4bd 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,14 +1,13 @@ -r base.txt Brotli==1.0.9 -coverage==6.4.2 +coverage==6.5.0 cryptography==36.0.1; platform_machine!="i686" # no 32-bit wheels; no python 3.9 wheels yet freezegun==1.1.0 mypy==0.982; implementation_name=="cpython" -mypy-extensions==0.4.3; implementation_name=="cpython" proxy.py ~= 2.4.4rc3 -pytest==6.2.5 -pytest-cov==3.0.0 -pytest-mock==3.6.1 +pytest==7.4.0 +pytest-cov==4.1.0 +pytest-mock==3.11.1 python-on-whales==0.36.1 re-assert==1.1.0 setuptools-git==1.2 diff --git a/setup.cfg b/setup.cfg index d270e811d97..cc2b8b40c91 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,8 @@ name = aiohttp version = attr: aiohttp.__version__ url = https://github.com/aio-libs/aiohttp project_urls = - Chat: Gitter = https://gitter.im/aio-libs/Lobby + Chat: Matrix = https://matrix.to/#/#aio-libs:matrix.org + Chat: Matrix Space = https://matrix.to/#/#aio-libs-space:matrix.org CI: GitHub Actions = https://github.com/aio-libs/aiohttp/actions?query=workflow%%3ACI Coverage: codecov = https://codecov.io/github/aio-libs/aiohttp Docs: Changelog = https://docs.aiohttp.org/en/stable/changes.html @@ -32,7 +33,6 @@ classifiers = Programming Language :: Python Programming Language :: Python :: 3 - Programming Language :: Python :: 3.7 Programming Language :: Python :: 3.8 Programming Language :: Python :: 3.9 Programming Language :: Python :: 3.10 @@ -40,7 +40,7 @@ classifiers = Topic :: Internet :: WWW/HTTP [options] -python_requires = >=3.7 +python_requires = >=3.8 packages = find: # https://setuptools.readthedocs.io/en/latest/setuptools.html#setting-the-zip-safe-flag zip_safe = False @@ -50,9 +50,7 @@ install_requires = charset-normalizer >=2.0, < 4.0 multidict >=4.5, < 7.0 async_timeout >= 4.0, < 5.0 - asynctest == 0.13.0; python_version<"3.8" yarl >= 1.0, < 2.0 - typing_extensions >= 3.7.4 frozenlist >= 1.1.1 aiosignal >= 1.1.2 @@ -88,6 +86,14 @@ zip_ok = false # TODO: don't disable D*, fix up issues instead ignore = N801,N802,N803,E203,E226,E305,W504,E252,E301,E302,E704,W503,W504,F811,D1,D4 max-line-length = 88 +per-file-ignores = + # I900: Shouldn't appear in requirements for examples. + examples/*:I900 + +# flake8-requirements +known-modules = proxy.py:[proxy] +requirements-file = requirements/test.txt +requirements-max-depth = 4 [isort] line_length=88 diff --git a/setup.py b/setup.py index b27a54d614b..2622e5c5223 100644 --- a/setup.py +++ b/setup.py @@ -4,8 +4,8 @@ from setuptools import Extension, setup -if sys.version_info < (3, 7): - raise RuntimeError("aiohttp 4.x requires Python 3.7+") +if sys.version_info < (3, 8): + raise RuntimeError("aiohttp 4.x requires Python 3.8+") NO_EXTENSIONS: bool = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS")) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/autobahn/test_autobahn.py b/tests/autobahn/test_autobahn.py index 5d72e37a17a..32cdacb5fa4 100644 --- a/tests/autobahn/test_autobahn.py +++ b/tests/autobahn/test_autobahn.py @@ -16,7 +16,6 @@ def report_dir(tmp_path_factory: TempPathFactory) -> Path: @pytest.fixture(scope="session", autouse=True) def build_autobahn_testsuite() -> Generator[None, None, None]: - try: docker.build( file="tests/autobahn/Dockerfile.autobahn", diff --git a/tests/conftest.py b/tests/conftest.py index 6e0cf73f93c..c2696ef2f7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,9 +15,6 @@ from aiohttp.test_utils import loop_context -IS_LINUX: bool -IS_UNIX: bool -needs_unix: bool try: import trustme @@ -30,14 +27,8 @@ pytest_plugins: List[str] = ["aiohttp.pytest_plugin", "pytester"] -IS_HPUX: bool = sys.platform.startswith("hp-ux") -# Specifies whether the current runtime is HP-UX. +IS_HPUX = sys.platform.startswith("hp-ux") IS_LINUX = sys.platform.startswith("linux") -# Specifies whether the current runtime is HP-UX. -IS_UNIX = hasattr(socket, "AF_UNIX") -# Specifies whether the current runtime is *NIX. - -needs_unix = pytest.mark.skipif(not IS_UNIX, reason="requires UNIX sockets") @pytest.fixture @@ -120,7 +111,7 @@ def unix_sockname(tmp_path: Any, tmp_path_factory: Any): # mostly 104 but sometimes it can be down to 100. # Ref: https://github.com/aio-libs/aiohttp/issues/3572 - if not IS_UNIX: + if not hasattr(socket, "AF_UNIX"): pytest.skip("requires UNIX sockets") max_sock_len = 92 if IS_HPUX else 108 if IS_LINUX else 100 @@ -188,12 +179,31 @@ def assert_sock_fits(sock_path): @pytest.fixture def selector_loop() -> None: - if sys.version_info >= (3, 8): - policy = asyncio.WindowsSelectorEventLoopPolicy() - else: - policy = asyncio.DefaultEventLoopPolicy() + policy = asyncio.WindowsSelectorEventLoopPolicy() asyncio.set_event_loop_policy(policy) with loop_context(policy.new_event_loop) as _loop: asyncio.set_event_loop(_loop) yield _loop + + +@pytest.fixture +def netrc_contents( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, + request: pytest.FixtureRequest, +): + """ + Prepare :file:`.netrc` with given contents. + + Monkey-patches :envvar:`NETRC` to point to created file. + """ + netrc_contents = getattr(request, "param", None) + + netrc_file_path = tmp_path / ".netrc" + if netrc_contents is not None: + netrc_file_path.write_text(netrc_contents) + + monkeypatch.setenv("NETRC", str(netrc_file_path)) + + return netrc_file_path diff --git a/tests/test_base_protocol.py b/tests/test_base_protocol.py index ff7fd0be523..7254c6ff8cb 100644 --- a/tests/test_base_protocol.py +++ b/tests/test_base_protocol.py @@ -22,6 +22,42 @@ async def test_pause_writing() -> None: assert pr._paused +async def test_pause_reading_no_transport() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop) + assert not pr._reading_paused + pr.pause_reading() + assert not pr._reading_paused + + +async def test_pause_reading_stub_transport() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop) + tr = asyncio.Transport() + pr.transport = tr + assert not pr._reading_paused + pr.pause_reading() + assert pr._reading_paused + + +async def test_resume_reading_no_transport() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop) + pr._reading_paused = True + pr.resume_reading() + assert pr._reading_paused + + +async def test_resume_reading_stub_transport() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop) + tr = asyncio.Transport() + pr.transport = tr + pr._reading_paused = True + pr.resume_reading() + assert not pr._reading_paused + + async def test_resume_writing_no_waiters() -> None: loop = asyncio.get_event_loop() pr = BaseProtocol(loop=loop) @@ -31,6 +67,17 @@ async def test_resume_writing_no_waiters() -> None: assert not pr._paused +async def test_resume_writing_waiter_done() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + waiter = mock.Mock(done=mock.Mock(return_value=True)) + pr._drain_waiter = waiter + pr._paused = True + pr.resume_writing() + assert not pr._paused + assert waiter.mock_calls == [mock.call.done()] + + async def test_connection_made() -> None: loop = asyncio.get_event_loop() pr = BaseProtocol(loop=loop) @@ -45,10 +92,10 @@ async def test_connection_lost_not_paused() -> None: pr = BaseProtocol(loop=loop) tr = mock.Mock() pr.connection_made(tr) - assert not pr._connection_lost + assert pr.connected pr.connection_lost(None) assert pr.transport is None - assert pr._connection_lost + assert not pr.connected async def test_connection_lost_paused_without_waiter() -> None: @@ -56,11 +103,22 @@ async def test_connection_lost_paused_without_waiter() -> None: pr = BaseProtocol(loop=loop) tr = mock.Mock() pr.connection_made(tr) - assert not pr._connection_lost + assert pr.connected pr.pause_writing() pr.connection_lost(None) assert pr.transport is None - assert pr._connection_lost + assert not pr.connected + + +async def test_connection_lost_waiter_done() -> None: + loop = asyncio.get_event_loop() + pr = BaseProtocol(loop=loop) + pr._paused = True + waiter = mock.Mock(done=mock.Mock(return_value=True)) + pr._drain_waiter = waiter + pr.connection_lost(None) + assert pr._drain_waiter is None + assert waiter.mock_calls == [mock.call.done()] async def test_drain_lost() -> None: diff --git a/tests/test_circular_imports.py b/tests/test_circular_imports.py index daac4c9d395..a3d09811000 100644 --- a/tests/test_circular_imports.py +++ b/tests/test_circular_imports.py @@ -10,6 +10,7 @@ """ # noqa: E501 import os import pkgutil +import socket import subprocess import sys from itertools import chain @@ -22,8 +23,6 @@ if TYPE_CHECKING: from _pytest.mark.structures import ParameterSet -from conftest import IS_UNIX # type: ignore[attr-defined] - import aiohttp @@ -33,7 +32,9 @@ def _mark_aiohttp_worker_for_skipping( return [ pytest.param( importable, - marks=pytest.mark.skipif(not IS_UNIX, reason="It's a UNIX-only module"), + marks=pytest.mark.skipif( + not hasattr(socket, "AF_UNIX"), reason="It's a UNIX-only module" + ), ) if importable == "aiohttp.worker" else importable diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index 4087670a2a0..4edc00483cb 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -9,7 +9,7 @@ import pathlib import socket import ssl -from typing import Any +from typing import Any, AsyncIterator from unittest import mock import pytest @@ -673,6 +673,25 @@ async def handler(request): await resp.content.read() +async def test_read_timeout_on_write(aiohttp_client: Any) -> None: + async def gen_payload() -> AsyncIterator[str]: + # Delay writing to ensure read timeout isn't triggered before writing completes. + await asyncio.sleep(0.5) + yield b"foo" + + async def handler(request: web.Request) -> web.Response: + return web.Response(body=await request.read()) + + app = web.Application() + app.router.add_put("/", handler) + + timeout = aiohttp.ClientTimeout(total=None, sock_read=0.1) + client = await aiohttp_client(app) + async with client.put("/", data=gen_payload(), timeout=timeout) as resp: + result = await resp.read() # Should not trigger a read timeout. + assert result == b"foo" + + async def test_timeout_on_reading_data(aiohttp_client: Any, mocker: Any) -> None: loop = asyncio.get_event_loop() @@ -2578,7 +2597,6 @@ async def handler(request): async def test_aiohttp_request_ctx_manager_not_found() -> None: - with pytest.raises(aiohttp.ClientConnectionError): async with aiohttp.request("GET", "http://wrong-dns-name.com"): assert False, "never executed" # pragma: no cover @@ -3004,6 +3022,30 @@ async def handler(request): await resp.read() +async def test_timeout_with_full_buffer(aiohttp_client: Any) -> None: + async def handler(request): + """Server response that never ends and always has more data available.""" + resp = web.StreamResponse() + await resp.prepare(request) + while True: + await resp.write(b"1" * 1000) + await asyncio.sleep(0.01) + + async def request(client): + timeout = aiohttp.ClientTimeout(total=0.5) + async with await client.get("/", timeout=timeout) as resp: + with pytest.raises(asyncio.TimeoutError): + async for data in resp.content.iter_chunked(1): + await asyncio.sleep(0.01) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + # wait_for() used just to ensure that a failing test doesn't hang. + await asyncio.wait_for(request(client), 1) + + async def test_read_bufsize_session_default(aiohttp_client: Any) -> None: async def handler(request): return web.Response(body=b"1234567") @@ -3047,3 +3089,105 @@ async def handler(request): assert resp.status == 200 assert await resp.text() == "ok" assert resp.headers["Content-Type"] == "text/plain; charset=utf-8" + + +async def test_max_field_size_session_default(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(headers={"Custom": "x" * 8190}) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with await client.get("/") as resp: + assert resp.headers["Custom"] == "x" * 8190 + + +async def test_max_field_size_session_default_fail(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(headers={"Custom": "x" * 8191}) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + with pytest.raises(aiohttp.ClientResponseError): + await client.get("/") + + +async def test_max_field_size_session_explicit(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(headers={"Custom": "x" * 8191}) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app, max_field_size=8191) + + async with await client.get("/") as resp: + assert resp.headers["Custom"] == "x" * 8191 + + +async def test_max_field_size_request_explicit(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(headers={"Custom": "x" * 8191}) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with await client.get("/", max_field_size=8191) as resp: + assert resp.headers["Custom"] == "x" * 8191 + + +async def test_max_line_size_session_default(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(status=200, reason="x" * 8190) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with await client.get("/") as resp: + assert resp.reason == "x" * 8190 + + +async def test_max_line_size_session_default_fail(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(status=200, reason="x" * 8192) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + with pytest.raises(aiohttp.ClientResponseError): + await client.get("/") + + +async def test_max_line_size_session_explicit(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(status=200, reason="x" * 8191) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app, max_line_size=8191) + + async with await client.get("/") as resp: + assert resp.reason == "x" * 8191 + + +async def test_max_line_size_request_explicit(aiohttp_client: Any) -> None: + async def handler(request): + return web.Response(status=200, reason="x" * 8191) + + app = web.Application() + app.add_routes([web.get("/", handler)]) + + client = await aiohttp_client(app) + + async with await client.get("/", max_line_size=8191) as resp: + assert resp.reason == "x" * 8191 diff --git a/tests/test_client_proto.py b/tests/test_client_proto.py index 08ae367e81b..e9e6e53a166 100644 --- a/tests/test_client_proto.py +++ b/tests/test_client_proto.py @@ -109,12 +109,15 @@ async def test_empty_data(loop: Any) -> None: async def test_schedule_timeout(loop: Any) -> None: proto = ResponseHandler(loop=loop) proto.set_response_params(read_timeout=1) + assert proto._read_timeout_handle is None + proto.start_timeout() assert proto._read_timeout_handle is not None async def test_drop_timeout(loop: Any) -> None: proto = ResponseHandler(loop=loop) proto.set_response_params(read_timeout=1) + proto.start_timeout() assert proto._read_timeout_handle is not None proto._drop_timeout() assert proto._read_timeout_handle is None @@ -123,6 +126,7 @@ async def test_drop_timeout(loop: Any) -> None: async def test_reschedule_timeout(loop: Any) -> None: proto = ResponseHandler(loop=loop) proto.set_response_params(read_timeout=1) + proto.start_timeout() assert proto._read_timeout_handle is not None h = proto._read_timeout_handle proto._reschedule_timeout() @@ -133,6 +137,21 @@ async def test_reschedule_timeout(loop: Any) -> None: async def test_eof_received(loop: Any) -> None: proto = ResponseHandler(loop=loop) proto.set_response_params(read_timeout=1) + proto.start_timeout() assert proto._read_timeout_handle is not None proto.eof_received() assert proto._read_timeout_handle is None + + +async def test_connection_lost_sets_transport_to_none(loop: Any, mocker: Any) -> None: + """Ensure that the transport is set to None when the connection is lost. + + This ensures the writer knows that the connection is closed. + """ + proto = ResponseHandler(loop=loop) + proto.connection_made(mocker.Mock()) + assert proto.transport is not None + + proto.connection_lost(OSError()) + + assert proto.transport is None diff --git a/tests/test_client_request.py b/tests/test_client_request.py index ecda7b0dc93..b0b11fda7e3 100644 --- a/tests/test_client_request.py +++ b/tests/test_client_request.py @@ -5,7 +5,7 @@ import pathlib import zlib from http.cookies import BaseCookie, Morsel, SimpleCookie -from typing import Any +from typing import Any, Optional from unittest import mock import pytest @@ -13,7 +13,7 @@ from yarl import URL import aiohttp -from aiohttp import BaseConnector, hdrs, payload +from aiohttp import BaseConnector, hdrs, helpers, payload from aiohttp.client_reqrep import ( ClientRequest, ClientResponse, @@ -1230,3 +1230,51 @@ def test_loose_cookies_types(loop: Any) -> None: def test_gen_default_accept_encoding(has_brotli: Any, expected: Any) -> None: with mock.patch("aiohttp.client_reqrep.HAS_BROTLI", has_brotli): assert _gen_default_accept_encoding() == expected + + +@pytest.mark.parametrize( + ("netrc_contents", "expected_auth"), + [ + ( + "machine example.com login username password pass\n", + helpers.BasicAuth("username", "pass"), + ) + ], + indirect=("netrc_contents",), +) +@pytest.mark.usefixtures("netrc_contents") +def test_basicauth_from_netrc_present( + make_request: Any, + expected_auth: Optional[helpers.BasicAuth], +): + """Test appropriate Authorization header is sent when netrc is not empty.""" + req = make_request("get", "http://example.com", trust_env=True) + assert req.headers[hdrs.AUTHORIZATION] == expected_auth.encode() + + +@pytest.mark.parametrize( + "netrc_contents", + ("machine example.com login username password pass\n",), + indirect=("netrc_contents",), +) +@pytest.mark.usefixtures("netrc_contents") +def test_basicauth_from_netrc_present_untrusted_env( + make_request: Any, +): + """Test no authorization header is sent via netrc if trust_env is False""" + req = make_request("get", "http://example.com", trust_env=False) + assert hdrs.AUTHORIZATION not in req.headers + + +@pytest.mark.parametrize( + "netrc_contents", + ("",), + indirect=("netrc_contents",), +) +@pytest.mark.usefixtures("netrc_contents") +def test_basicauth_from_empty_netrc( + make_request: Any, +): + """Test that no Authorization header is sent when netrc is empty""" + req = make_request("get", "http://example.com", trust_env=True) + assert hdrs.AUTHORIZATION not in req.headers diff --git a/tests/test_client_session.py b/tests/test_client_session.py index 7592349795f..8b76c25a3ad 100644 --- a/tests/test_client_session.py +++ b/tests/test_client_session.py @@ -4,7 +4,6 @@ import gc import io import json -import sys from http.cookies import SimpleCookie from typing import Any, List from unittest import mock @@ -160,7 +159,6 @@ async def test_merge_headers_with_list_of_tuples_duplicated_names( def test_http_GET(session: Any, params: Any) -> None: - # Python 3.8 will auto use mock.AsyncMock, it has different behavior with mock.patch( "aiohttp.client.ClientSession._request", new_callable=mock.MagicMock ) as patched: @@ -572,7 +570,6 @@ async def on_request_headers_sent(session, context, params): async with session.post( "/", data=body, trace_request_ctx=trace_request_ctx ) as resp: - await resp.json() on_request_start.assert_called_once_with( @@ -683,14 +680,7 @@ def to_url(path: str) -> URL: # Exception with mock.patch("aiohttp.client.TCPConnector.connect") as connect_patched: - error = Exception() - if sys.version_info >= (3, 8, 1): - connect_patched.side_effect = error - else: - loop = asyncio.get_event_loop() - f = loop.create_future() - f.set_exception(error) - connect_patched.return_value = f + connect_patched.side_effect = Exception() for req in [ lambda: session.get("/?x=0"), @@ -718,13 +708,7 @@ async def test_request_tracing_exception() -> None: with mock.patch("aiohttp.client.TCPConnector.connect") as connect_patched: error = Exception() - if sys.version_info >= (3, 8, 1): - connect_patched.side_effect = error - else: - loop = asyncio.get_event_loop() - f = loop.create_future() - f.set_exception(error) - connect_patched.return_value = f + connect_patched.side_effect = error session = aiohttp.ClientSession(trace_configs=[trace_config]) diff --git a/tests/test_connector.py b/tests/test_connector.py index 148b4cd742f..9ef71882f87 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -14,7 +14,6 @@ from unittest import mock import pytest -from conftest import needs_unix from yarl import URL import aiohttp @@ -762,7 +761,6 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop: Any) - async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( loop: Any, dns_response: Any ) -> None: - with mock.patch("aiohttp.connector.DefaultResolver") as m_resolver: conn = aiohttp.TCPConnector(use_dns_cache=True, ttl_dns_cache=10) m_resolver().resolve.return_value = dns_response() @@ -789,7 +787,6 @@ async def coro(): async def test_tcp_connector_cancel_dns_error_captured( loop: Any, dns_response_error: Any ) -> None: - exception_handler_called = False def exception_handler(loop, context): @@ -918,7 +915,6 @@ async def test_tcp_connector_dns_tracing_cache_disabled( async def test_tcp_connector_dns_tracing_throttle_requests( loop: Any, dns_response: Any ) -> None: - session = mock.Mock() trace_config_ctx = mock.Mock() on_dns_cache_hit = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) @@ -982,7 +978,6 @@ async def test_release_close_do_not_add_to_pool(loop: Any, key: Any) -> None: async def test_release_close_do_not_delete_existing_connections( loop: Any, key: Any ) -> None: - proto1 = create_mocked_conn(loop) conn = aiohttp.BaseConnector() @@ -1490,7 +1485,6 @@ async def test_connect_reuseconn_tracing(loop: Any, key: Any) -> None: async def test_connect_with_limit_and_limit_per_host(loop: Any, key: Any) -> None: - proto = create_mocked_conn(loop) proto.is_connected.return_value = True @@ -1589,7 +1583,6 @@ async def f(): async def test_connect_with_limit_cancelled(loop: Any) -> None: - proto = create_mocked_conn(loop) proto.is_connected.return_value = True @@ -1869,7 +1862,6 @@ async def create_connection(req, traces=None): async def test_error_on_connection_with_cancelled_waiter(loop: Any, key: Any) -> None: - conn = aiohttp.BaseConnector(limit=1) req = mock.Mock() @@ -1931,7 +1923,7 @@ async def handler(request): assert r.status == 200 -@needs_unix +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") async def test_unix_connector_not_found(loop: Any) -> None: connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex) @@ -1940,7 +1932,7 @@ async def test_unix_connector_not_found(loop: Any) -> None: await connector.connect(req, None, ClientTimeout()) -@needs_unix +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") async def test_unix_connector_permission(loop: Any) -> None: loop.create_unix_connection = make_mocked_coro(raise_exception=PermissionError()) connector = aiohttp.UnixConnector("/" + uuid.uuid4().hex) @@ -2030,8 +2022,7 @@ async def handler(request): session = aiohttp.ClientSession(connector=conn) url = srv.make_url("/") - err = aiohttp.ClientConnectorCertificateError - with pytest.raises(err) as ctx: + with pytest.raises(aiohttp.ClientConnectorCertificateError) as ctx: await session.get(url) assert isinstance(ctx.value, aiohttp.ClientConnectorCertificateError) diff --git a/tests/test_cookiejar.py b/tests/test_cookiejar.py index 7b5c91851ae..3e89ba730a8 100644 --- a/tests/test_cookiejar.py +++ b/tests/test_cookiejar.py @@ -3,6 +3,7 @@ import datetime import itertools import pathlib +import pickle import unittest from http.cookies import BaseCookie, Morsel, SimpleCookie from typing import Any @@ -15,6 +16,13 @@ from aiohttp import CookieJar, DummyCookieJar +def dump_cookiejar() -> bytes: # pragma: no cover + """Create pickled data for test_pickle_format().""" + cj = CookieJar() + cj.update_cookies(cookies_to_send.__pytest_wrapped__.obj()) + return pickle.dumps(cj._cookies, pickle.HIGHEST_PROTOCOL) + + @pytest.fixture def cookies_to_send(): return SimpleCookie( @@ -31,7 +39,7 @@ def cookies_to_send(): "path3-cookie=eleventh; Domain=pathtest.com; Path=/one/two; " "path4-cookie=twelfth; Domain=pathtest.com; Path=/one/two/; " "expires-cookie=thirteenth; Domain=expirestest.com; Path=/;" - " Expires=Tue, 1 Jan 2039 12:00:00 GMT; " + " Expires=Tue, 1 Jan 2999 12:00:00 GMT; " "max-age-cookie=fourteenth; Domain=maxagetest.com; Path=/;" " Max-Age=60; " "invalid-max-age-cookie=fifteenth; Domain=invalid-values.com; " @@ -165,9 +173,7 @@ def test_path_matching() -> None: assert not test_func("/different-folder/", "/folder/") -async def test_constructor( - loop: Any, cookies_to_send: Any, cookies_to_receive: Any -) -> None: +async def test_constructor(cookies_to_send: Any, cookies_to_receive: Any) -> None: jar = CookieJar() jar.update_cookies(cookies_to_send) jar_cookies = SimpleCookie() @@ -175,11 +181,10 @@ async def test_constructor( dict.__setitem__(jar_cookies, cookie.key, cookie) expected_cookies = cookies_to_send assert jar_cookies == expected_cookies - assert jar._loop is loop async def test_constructor_with_expired( - loop: Any, cookies_to_send_with_expired: Any, cookies_to_receive: Any + cookies_to_send_with_expired: Any, cookies_to_receive: Any ) -> None: jar = CookieJar() jar.update_cookies(cookies_to_send_with_expired) @@ -188,7 +193,6 @@ async def test_constructor_with_expired( dict.__setitem__(jar_cookies, cookie.key, cookie) expected_cookies = cookies_to_send_with_expired assert jar_cookies != expected_cookies - assert jar._loop is loop async def test_save_load( @@ -520,7 +524,6 @@ def test_path_filter_root(self) -> None: ) def test_path_filter_folder(self) -> None: - cookies_sent, _ = self.request_reply_with_same_url("http://pathtest.com/one/") self.assertEqual( @@ -529,7 +532,6 @@ def test_path_filter_folder(self) -> None: ) def test_path_filter_file(self) -> None: - cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/one/two" ) @@ -546,7 +548,6 @@ def test_path_filter_file(self) -> None: ) def test_path_filter_subfolder(self) -> None: - cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/one/two/" ) @@ -564,7 +565,6 @@ def test_path_filter_subfolder(self) -> None: ) def test_path_filter_subsubfolder(self) -> None: - cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/one/two/three/" ) @@ -582,7 +582,6 @@ def test_path_filter_subsubfolder(self) -> None: ) def test_path_filter_different_folder(self) -> None: - cookies_sent, _ = self.request_reply_with_same_url( "http://pathtest.com/hundred/" ) @@ -782,6 +781,37 @@ async def test_cookie_jar_clear_domain() -> None: next(iterator) +def test_pickle_format(cookies_to_send) -> None: + """Test if cookiejar pickle format breaks. + + If this test fails, it may indicate that saved cookiejars will stop working. + If that happens then: + 1. Avoid releasing the change in a bugfix release. + 2. Try to include a migration script in the release notes (example below). + 3. Use dump_cookiejar() at the top of this file to update `pickled`. + + Depending on the changes made, a migration script might look like: + import pickle + with file_path.open("rb") as f: + cookies = pickle.load(f) + + morsels = [(name, m) for c in cookies.values() for name, m in c.items()] + cookies.clear() + for name, m in morsels: + cookies[(m["domain"], m["path"])][name] = m + + with file_path.open("wb") as f: + pickle.dump(cookies, f, pickle.HIGHEST_PROTOCOL) + """ + pickled = b"\x80\x05\x95\xc5\x07\x00\x00\x00\x00\x00\x00\x8c\x0bcollections\x94\x8c\x0bdefaultdict\x94\x93\x94\x8c\x0chttp.cookies\x94\x8c\x0cSimpleCookie\x94\x93\x94\x85\x94R\x94(\x8c\x00\x94\x8c\x01/\x94\x86\x94h\x05)\x81\x94\x8c\rshared-cookie\x94h\x03\x8c\x06Morsel\x94\x93\x94)\x81\x94(\x8c\x07expires\x94h\x08\x8c\x04path\x94h\t\x8c\x07comment\x94h\x08\x8c\x06domain\x94h\x08\x8c\x07max-age\x94h\x08\x8c\x06secure\x94h\x08\x8c\x08httponly\x94h\x08\x8c\x07version\x94h\x08\x8c\x08samesite\x94h\x08u}\x94(\x8c\x03key\x94h\x0c\x8c\x05value\x94\x8c\x05first\x94\x8c\x0bcoded_value\x94h\x1cubs\x8c\x0bexample.com\x94h\t\x86\x94h\x05)\x81\x94(\x8c\rdomain-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h\x1eh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah!h\x1b\x8c\x06second\x94h\x1dh$ub\x8c\x14dotted-domain-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13\x8c\x0bexample.com\x94h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah%h\x1b\x8c\x05fifth\x94h\x1dh)ubu\x8c\x11test1.example.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x11subdomain1-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h*h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah-h\x1b\x8c\x05third\x94h\x1dh0ubs\x8c\x11test2.example.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x11subdomain2-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h1h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah4h\x1b\x8c\x06fourth\x94h\x1dh7ubs\x8c\rdifferent.org\x94h\t\x86\x94h\x05)\x81\x94\x8c\x17different-domain-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h8h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah;h\x1b\x8c\x05sixth\x94h\x1dh>ubs\x8c\nsecure.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\rsecure-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13h?h\x14h\x08h\x15\x88h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahBh\x1b\x8c\x07seventh\x94h\x1dhEubs\x8c\x0cpathtest.com\x94h\t\x86\x94h\x05)\x81\x94(\x8c\x0eno-path-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13hFh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahIh\x1b\x8c\x06eighth\x94h\x1dhLub\x8c\x0cpath1-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13\x8c\x0cpathtest.com\x94h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahMh\x1b\x8c\x05ninth\x94h\x1dhQubu\x8c\x0cpathtest.com\x94\x8c\x04/one\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath2-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11hSh\x12h\x08h\x13hRh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahVh\x1b\x8c\x05tenth\x94h\x1dhYubs\x8c\x0cpathtest.com\x94\x8c\x08/one/two\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath3-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h[h\x12h\x08h\x13hZh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah^h\x1b\x8c\x08eleventh\x94h\x1dhaubs\x8c\x0cpathtest.com\x94\x8c\t/one/two/\x94\x86\x94h\x05)\x81\x94\x8c\x0cpath4-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11hch\x12h\x08h\x13hbh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahfh\x1b\x8c\x07twelfth\x94h\x1dhiubs\x8c\x0fexpirestest.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x0eexpires-cookie\x94h\x0e)\x81\x94(h\x10\x8c\x1cTue, 1 Jan 2999 12:00:00 GMT\x94h\x11h\th\x12h\x08h\x13hjh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahmh\x1b\x8c\nthirteenth\x94h\x1dhqubs\x8c\x0emaxagetest.com\x94h\t\x86\x94h\x05)\x81\x94\x8c\x0emax-age-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13hrh\x14\x8c\x0260\x94h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ahuh\x1b\x8c\nfourteenth\x94h\x1dhyubs\x8c\x12invalid-values.com\x94h\t\x86\x94h\x05)\x81\x94(\x8c\x16invalid-max-age-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13hzh\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah}h\x1b\x8c\tfifteenth\x94h\x1dh\x80ub\x8c\x16invalid-expires-cookie\x94h\x0e)\x81\x94(h\x10h\x08h\x11h\th\x12h\x08h\x13\x8c\x12invalid-values.com\x94h\x14h\x08h\x15h\x08h\x16h\x08h\x17h\x08h\x18h\x08u}\x94(h\x1ah\x81h\x1b\x8c\tsixteenth\x94h\x1dh\x85ubuu." # noqa: E501 + cookies = pickle.loads(pickled) + + cj = CookieJar() + cj.update_cookies(cookies_to_send) + + assert cookies == cj._cookies + + @pytest.mark.parametrize( "url", [ diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 37aaff6b9f7..0f37089d59f 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -6,6 +6,7 @@ import platform import weakref from math import ceil, modf +from pathlib import Path from unittest import mock from urllib.request import getproxies_environment @@ -190,7 +191,6 @@ def test_basic_auth_from_not_url() -> None: class ReifyMixin: - reify = NotImplemented def test_reify(self) -> None: @@ -975,3 +975,86 @@ def test_populate_with_cookies(): ) def test_parse_http_date(value, expected): assert parse_http_date(value) == expected + + +@pytest.mark.parametrize( + ["netrc_contents", "expected_username"], + [ + ( + "machine example.com login username password pass\n", + "username", + ), + ], + indirect=("netrc_contents",), +) +@pytest.mark.usefixtures("netrc_contents") +def test_netrc_from_env(expected_username: str): + """Test that reading netrc files from env works as expected""" + netrc_obj = helpers.netrc_from_env() + assert netrc_obj.authenticators("example.com")[0] == expected_username + + +@pytest.fixture +def protected_dir(tmp_path: Path): + protected_dir = tmp_path / "protected" + protected_dir.mkdir() + try: + protected_dir.chmod(0o600) + yield protected_dir + finally: + protected_dir.rmdir() + + +def test_netrc_from_home_does_not_raise_if_access_denied( + protected_dir: Path, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setattr(Path, "home", lambda: protected_dir) + monkeypatch.delenv("NETRC", raising=False) + + helpers.netrc_from_env() + + +@pytest.mark.parametrize( + ["netrc_contents", "expected_auth"], + [ + ( + "machine example.com login username password pass\n", + helpers.BasicAuth("username", "pass"), + ), + ( + "machine example.com account username password pass\n", + helpers.BasicAuth("username", "pass"), + ), + ( + "machine example.com password pass\n", + helpers.BasicAuth("", "pass"), + ), + ], + indirect=("netrc_contents",), +) +@pytest.mark.usefixtures("netrc_contents") +def test_basicauth_present_in_netrc( + expected_auth: helpers.BasicAuth, +): + """Test that netrc file contents are properly parsed into BasicAuth tuples""" + netrc_obj = helpers.netrc_from_env() + + assert expected_auth == helpers.basicauth_from_netrc(netrc_obj, "example.com") + + +@pytest.mark.parametrize( + ["netrc_contents"], + [ + ("",), + ], + indirect=("netrc_contents",), +) +@pytest.mark.usefixtures("netrc_contents") +def test_read_basicauth_from_empty_netrc(): + """Test that an error is raised if netrc doesn't have an entry for our host""" + netrc_obj = helpers.netrc_from_env() + + with pytest.raises( + LookupError, match="No entry for example.com found in the `.netrc` file." + ): + helpers.basicauth_from_netrc(netrc_obj, "example.com") diff --git a/tests/test_http_exceptions.py b/tests/test_http_exceptions.py index 27aab67089f..28fdcbe0c69 100644 --- a/tests/test_http_exceptions.py +++ b/tests/test_http_exceptions.py @@ -32,13 +32,13 @@ def test_str(self) -> None: err = http_exceptions.HttpProcessingError( code=500, message="Internal error", headers={} ) - assert str(err) == "500, message='Internal error'" + assert str(err) == "500, message:\n Internal error" def test_repr(self) -> None: err = http_exceptions.HttpProcessingError( code=500, message="Internal error", headers={} ) - assert repr(err) == ("") + assert repr(err) == ("") class TestBadHttpMessage: @@ -61,7 +61,7 @@ def test_pickle(self) -> None: def test_str(self) -> None: err = http_exceptions.BadHttpMessage(message="Bad HTTP message", headers={}) - assert str(err) == "400, message='Bad HTTP message'" + assert str(err) == "400, message:\n Bad HTTP message" def test_repr(self) -> None: err = http_exceptions.BadHttpMessage(message="Bad HTTP message", headers={}) @@ -88,9 +88,8 @@ def test_pickle(self) -> None: def test_str(self) -> None: err = http_exceptions.LineTooLong(line="spam", limit="10", actual_size="12") - assert str(err) == ( - "400, message='Got more than 10 bytes (12) " "when reading spam.'" - ) + expected = "400, message:\n Got more than 10 bytes (12) when reading spam." + assert str(err) == expected def test_repr(self) -> None: err = http_exceptions.LineTooLong(line="spam", limit="10", actual_size="12") @@ -120,25 +119,24 @@ def test_pickle(self) -> None: def test_str(self) -> None: err = http_exceptions.InvalidHeader(hdr="X-Spam") - assert str(err) == "400, message='Invalid HTTP Header: X-Spam'" + assert str(err) == "400, message:\n Invalid HTTP Header: X-Spam" def test_repr(self) -> None: err = http_exceptions.InvalidHeader(hdr="X-Spam") - assert repr(err) == ( - "" - ) + expected = "" + assert repr(err) == expected class TestBadStatusLine: def test_ctor(self) -> None: err = http_exceptions.BadStatusLine("Test") assert err.line == "Test" - assert str(err) == "400, message=\"Bad status line 'Test'\"" + assert str(err) == "400, message:\n Bad status line 'Test'" def test_ctor2(self) -> None: err = http_exceptions.BadStatusLine(b"") assert err.line == "b''" - assert str(err) == "400, message='Bad status line \"b\\'\\'\"'" + assert str(err) == "400, message:\n Bad status line \"b''\"" def test_pickle(self) -> None: err = http_exceptions.BadStatusLine("Test") diff --git a/tests/test_http_parser.py b/tests/test_http_parser.py index 619a95b4b6f..8113fb94dd1 100644 --- a/tests/test_http_parser.py +++ b/tests/test_http_parser.py @@ -2,6 +2,7 @@ # Tests for aiohttp/protocol.py import asyncio +import re from typing import Any, List from unittest import mock from urllib.parse import quote @@ -117,6 +118,26 @@ def test_parse_headers(parser: Any) -> None: assert not msg.upgrade +@pytest.mark.skipif(NO_EXTENSIONS, reason="Only tests C parser.") +def test_invalid_character(loop: Any, protocol: Any, request: Any) -> None: + parser = HttpRequestParserC( + protocol, + loop, + 2**16, + max_line_size=8190, + max_field_size=8190, + ) + text = b"POST / HTTP/1.1\r\nHost: localhost:8080\r\nSet-Cookie: abc\x01def\r\n\r\n" + error_detail = re.escape( + r""": + + b'Set-Cookie: abc\x01def\r' + ^""" + ) + with pytest.raises(http_exceptions.BadHttpMessage, match=error_detail): + parser.feed_data(text) + + def test_parse_headers_longline(parser: Any) -> None: invalid_unicode_byte = b"\xd9" header_name = b"Test" + invalid_unicode_byte + b"Header" + b"A" * 8192 @@ -436,7 +457,7 @@ def test_max_header_field_size(parser: Any, size: Any) -> None: name = b"t" * size text = b"GET /test HTTP/1.1\r\n" + name + b":data\r\n\r\n" - match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + match = f"400, message:\n Got more than 8190 bytes \\({size}\\) when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(text) @@ -464,7 +485,7 @@ def test_max_header_value_size(parser: Any, size: Any) -> None: name = b"t" * size text = b"GET /test HTTP/1.1\r\n" b"data:" + name + b"\r\n\r\n" - match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + match = f"400, message:\n Got more than 8190 bytes \\({size}\\) when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(text) @@ -492,7 +513,7 @@ def test_max_header_value_size_continuation(parser: Any, size: Any) -> None: name = b"T" * (size - 5) text = b"GET /test HTTP/1.1\r\n" b"data: test\r\n " + name + b"\r\n\r\n" - match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + match = f"400, message:\n Got more than 8190 bytes \\({size}\\) when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(text) @@ -615,7 +636,7 @@ def test_http_request_parser_bad_version(parser: Any) -> None: @pytest.mark.parametrize("size", [40965, 8191]) def test_http_request_max_status_line(parser: Any, size: Any) -> None: path = b"t" * (size - 5) - match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + match = f"400, message:\n Got more than 8190 bytes \\({size}\\) when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): parser.feed_data(b"GET /path" + path + b" HTTP/1.1\r\n\r\n") @@ -660,7 +681,7 @@ def test_http_response_parser_bad_status_line_too_long( response: Any, size: Any ) -> None: reason = b"t" * (size - 2) - match = f"400, message='Got more than 8190 bytes \\({size}\\) when reading" + match = f"400, message:\n Got more than 8190 bytes \\({size}\\) when reading" with pytest.raises(http_exceptions.LineTooLong, match=match): response.feed_data(b"HTTP/1.1 200 Ok" + reason + b"\r\n\r\n") @@ -694,6 +715,7 @@ def test_http_response_parser_bad(response: Any) -> None: response.feed_data(b"HTT/1\r\n\r\n") +@pytest.mark.skipif(not NO_EXTENSIONS, reason="Behaviour has changed in C parser") def test_http_response_parser_code_under_100(response: Any) -> None: msg = response.feed_data(b"HTTP/1.1 99 test\r\n\r\n")[0][0][0] assert msg.code == 99 @@ -1119,7 +1141,7 @@ async def test_feed_data(self, stream: Any) -> None: dbuf = DeflateBuffer(buf, "deflate") dbuf.decompressor = mock.Mock() - dbuf.decompressor.decompress.return_value = b"line" + dbuf.decompressor.decompress_sync.return_value = b"line" # First byte should be b'x' in order code not to change the decoder. dbuf.feed_data(b"xxxx", 4) @@ -1133,7 +1155,7 @@ async def test_feed_data_err(self, stream: Any) -> None: exc = ValueError() dbuf.decompressor = mock.Mock() - dbuf.decompressor.decompress.side_effect = exc + dbuf.decompressor.decompress_sync.side_effect = exc with pytest.raises(http_exceptions.ContentEncodingError): # Should be more than 4 bytes to trigger deflate FSM error. diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index 3fb5531ca1d..5fb4e3b9a37 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -125,7 +125,6 @@ async def test_write_payload_chunked_filter_mutiple_chunks( async def test_write_payload_deflate_compression( protocol: Any, transport: Any, loop: Any ) -> None: - COMPRESSED = b"x\x9cKI,I\x04\x00\x04\x00\x01\x9b" write = transport.write = mock.Mock() msg = http.StreamWriter(protocol, loop) @@ -157,7 +156,6 @@ async def test_write_payload_deflate_and_chunked( async def test_write_payload_bytes_memoryview( buf: Any, protocol: Any, transport: Any, loop: Any ) -> None: - msg = http.StreamWriter(protocol, loop) mv = memoryview(b"abcd") @@ -262,6 +260,23 @@ async def test_write_to_closing_transport( await msg.write(b"After closing") +async def test_write_to_closed_transport( + protocol: Any, transport: Any, loop: Any +) -> None: + """Test that writing to a closed transport raises ConnectionResetError. + + The StreamWriter checks to see if protocol.transport is None before + writing to the transport. If it is None, it raises ConnectionResetError. + """ + msg = http.StreamWriter(protocol, loop) + + await msg.write(b"Before transport close") + protocol.transport = None + + with pytest.raises(ConnectionResetError, match="Cannot write to closing transport"): + await msg.write(b"After transport closed") + + async def test_drain(protocol: Any, transport: Any, loop: Any) -> None: msg = http.StreamWriter(protocol, loop) await msg.drain() diff --git a/tests/test_loop.py b/tests/test_loop.py index f5a4c7774e1..d9114ec200b 100644 --- a/tests/test_loop.py +++ b/tests/test_loop.py @@ -6,7 +6,6 @@ import pytest from aiohttp import web -from aiohttp.helpers import PY_38 from aiohttp.test_utils import AioHTTPTestCase, loop_context @@ -14,7 +13,6 @@ platform.system() == "Windows", reason="the test is not valid for Windows" ) async def test_subprocess_co(loop: Any) -> None: - assert PY_38 or threading.current_thread() is threading.main_thread() proc = await asyncio.create_subprocess_shell( "exit 0", stdin=asyncio.subprocess.DEVNULL, @@ -38,15 +36,11 @@ async def on_startup_hook(self, app: Any) -> None: async def test_on_startup_hook(self) -> None: self.assertTrue(self.on_startup_called) - def test_default_loop(self) -> None: - self.assertIs(self.loop, asyncio.get_event_loop_policy().get_event_loop()) - def test_default_loop(loop: Any) -> None: assert asyncio.get_event_loop_policy().get_event_loop() is loop -@pytest.mark.xfail(not PY_38, reason="ThreadedChildWatcher is only available in 3.8+") def test_setup_loop_non_main_thread() -> None: child_exc = None diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 15f9e56dfb8..4708710dff6 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -3,7 +3,6 @@ import io import json import pathlib -import sys import zlib from typing import Any, Optional from unittest import mock @@ -173,15 +172,9 @@ async def test_read_chunk_without_content_length(self, newline: Any) -> None: async def test_read_incomplete_chunk(self, newline: Any) -> None: with Stream(b"") as stream: - if sys.version_info >= (3, 8, 1): - # Workaround for a weird behavior of patch.object - def prepare(data): - return data - else: - - async def prepare(data): - return data + def prepare(data): + return data with mock.patch.object( stream, @@ -223,15 +216,9 @@ async def test_read_incomplete_body_chunked(self, newline: Any) -> None: async def test_read_boundary_with_incomplete_chunk(self, newline: Any) -> None: with Stream(b"") as stream: - if sys.version_info >= (3, 8, 1): - # Workaround for weird 3.8.1 patch.object() behavior - def prepare(data): - return data - - else: - async def prepare(data): - return data + def prepare(data): + return data with mock.patch.object( stream, @@ -394,7 +381,6 @@ async def test_read_with_content_transfer_encoding_base64( async def test_decode_with_content_transfer_encoding_base64( self, newline: Any ) -> None: - with Stream(b"VG\r\r\nltZSB0byBSZ\r\nWxheCE=%s--:--" % newline) as stream: obj = aiohttp.BodyPartReader( BOUNDARY, diff --git a/tests/test_payload.py b/tests/test_payload.py index d3ca69861e0..ca243fd0c66 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -115,6 +115,5 @@ async def gen() -> AsyncIterator[bytes]: def test_async_iterable_payload_not_async_iterable() -> None: - with pytest.raises(TypeError): payload.AsyncIterablePayload(object()) # type: ignore[arg-type] diff --git a/tests/test_proxy_functional.py b/tests/test_proxy_functional.py index 22f901a88ef..128785ffbfd 100644 --- a/tests/test_proxy_functional.py +++ b/tests/test_proxy_functional.py @@ -185,8 +185,8 @@ async def test_https_proxy_unsupported_tls_in_tls( "This support for TLS in TLS is known to be disabled " r"in the stdlib asyncio\. This is why you'll probably see " r"an error in the log below\.\n\n" - "It is possible to enable it via monkeypatching under " - r"Python 3\.7 or higher\. For more details, see:\n" + r"It is possible to enable it via monkeypatching\. " + r"For more details, see:\n" r"\* https://bugs\.python\.org/issue37179\n" r"\* https://github\.com/python/cpython/pull/28073\n\n" r"You can temporarily patch this as follows:\n" @@ -206,7 +206,10 @@ async def test_https_proxy_unsupported_tls_in_tls( r"$" ) - with pytest.warns(RuntimeWarning, match=expected_warning_text,), pytest.raises( + with pytest.warns( + RuntimeWarning, + match=expected_warning_text, + ), pytest.raises( ClientConnectionError, match=expected_exception_reason, ) as conn_err: diff --git a/tests/test_run_app.py b/tests/test_run_app.py index 46b868c3815..c3e52ef361b 100644 --- a/tests/test_run_app.py +++ b/tests/test_run_app.py @@ -9,14 +9,14 @@ import ssl import subprocess import sys -from typing import Any +import time +from typing import Any, Callable, NoReturn from unittest import mock from uuid import uuid4 import pytest -from conftest import needs_unix -from aiohttp import web +from aiohttp import ClientConnectorError, ClientSession, web from aiohttp.test_utils import make_mocked_coro from aiohttp.web_runner import BaseRunner @@ -54,12 +54,6 @@ HAS_IPV6 = False -# tokio event loop does not allow to override attributes -def skip_if_no_dict(loop: Any) -> None: - if not hasattr(loop, "__dict__"): - pytest.skip("can not override loop attributes") - - def skip_if_on_windows() -> None: if platform.system() == "Windows": pytest.skip("the test is not valid for Windows") @@ -67,7 +61,6 @@ def skip_if_on_windows() -> None: @pytest.fixture def patched_loop(loop: Any): - skip_if_no_dict(loop) server = mock.Mock() server.wait_closed = make_mocked_coro(None) loop.create_server = make_mocked_coro(server) @@ -548,7 +541,7 @@ def test_run_app_https_unix_socket(patched_loop: Any, unix_sockname: Any) -> Non assert f"https://unix:{unix_sockname}:" in printer.call_args[0][0] -@needs_unix +@pytest.mark.skipif(not hasattr(socket, "AF_UNIX"), reason="requires UNIX sockets") @skip_if_no_abstract_paths def test_run_app_abstract_linux_socket(patched_loop: Any) -> None: sock_path = b"\x00" + uuid4().hex.encode("ascii") @@ -926,3 +919,197 @@ async def init(): web.run_app(init(), print=stopper(patched_loop), loop=patched_loop) assert count == 3 + + +class TestShutdown: + def raiser(self) -> NoReturn: + raise KeyboardInterrupt + + async def stop(self, request: web.Request) -> web.Response: + asyncio.get_running_loop().call_soon(self.raiser) + return web.Response() + + def run_app(self, port: int, timeout: int, task, extra_test=None) -> asyncio.Task: + async def test() -> None: + await asyncio.sleep(1) + async with ClientSession() as sess: + async with sess.get(f"http://localhost:{port}/"): + pass + async with sess.get(f"http://localhost:{port}/stop"): + pass + + if extra_test: + await extra_test(sess) + + async def run_test(app: web.Application) -> None: + nonlocal test_task + test_task = asyncio.create_task(test()) + yield + await test_task + + async def handler(request: web.Request) -> web.Response: + nonlocal t + t = asyncio.create_task(task()) + return web.Response(text="FOO") + + t = test_task = None + app = web.Application() + app.cleanup_ctx.append(run_test) + app.router.add_get("/", handler) + app.router.add_get("/stop", self.stop) + + web.run_app(app, port=port, shutdown_timeout=timeout) + assert test_task.exception() is None + return t + + def test_shutdown_wait_for_task( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task(): + nonlocal finished + await asyncio.sleep(2) + finished = True + + t = self.run_app(port, 3, task) + + assert finished is True + assert t.done() + assert not t.cancelled() + + def test_shutdown_timeout_task( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task(): + nonlocal finished + await asyncio.sleep(2) + finished = True + + t = self.run_app(port, 1, task) + + assert finished is False + assert t.done() + assert t.cancelled() + + def test_shutdown_wait_for_spawned_task( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + finished_sub = False + sub_t = None + + async def sub_task(): + nonlocal finished_sub + await asyncio.sleep(1.5) + finished_sub = True + + async def task(): + nonlocal finished, sub_t + await asyncio.sleep(0.5) + sub_t = asyncio.create_task(sub_task()) + finished = True + + t = self.run_app(port, 3, task) + + assert finished is True + assert t.done() + assert not t.cancelled() + assert finished_sub is True + assert sub_t.done() + assert not sub_t.cancelled() + + def test_shutdown_timeout_not_reached( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task(): + nonlocal finished + await asyncio.sleep(1) + finished = True + + start_time = time.time() + t = self.run_app(port, 15, task) + + assert finished is True + assert t.done() + # Verify run_app has not waited for timeout. + assert time.time() - start_time < 10 + + def test_shutdown_new_conn_rejected( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def task() -> None: + nonlocal finished + await asyncio.sleep(9) + finished = True + + async def test(sess: ClientSession) -> None: + # Ensure we are in the middle of shutdown (waiting for task()). + await asyncio.sleep(1) + with pytest.raises(ClientConnectorError): + # Use a new session to try and open a new connection. + async with ClientSession() as sess: + async with sess.get(f"http://localhost:{port}/"): + pass + assert finished is False + + t = self.run_app(port, 10, task, test) + + assert finished is True + assert t.done() + + def test_shutdown_pending_handler_responds( + self, aiohttp_unused_port: Callable[[], int] + ) -> None: + port = aiohttp_unused_port() + finished = False + + async def test() -> None: + async def test_resp(sess): + async with sess.get(f"http://localhost:{port}/") as resp: + assert await resp.text() == "FOO" + + await asyncio.sleep(1) + async with ClientSession() as sess: + t = asyncio.create_task(test_resp(sess)) + await asyncio.sleep(1) + # Handler is in-progress while we trigger server shutdown. + async with sess.get(f"http://localhost:{port}/stop"): + pass + + assert finished is False + # Handler should still complete and produce a response. + await t + + async def run_test(app: web.Application) -> None: + nonlocal t + t = asyncio.create_task(test()) + yield + await t + + async def handler(request: web.Request) -> web.Response: + nonlocal finished + await asyncio.sleep(3) + finished = True + return web.Response(text="FOO") + + t = None + app = web.Application() + app.cleanup_ctx.append(run_test) + app.router.add_get("/", handler) + app.router.add_get("/stop", self.stop) + + web.run_app(app, port=port, shutdown_timeout=5) + assert t.exception() is None + assert finished is True diff --git a/tests/test_streams.py b/tests/test_streams.py index 47cba7b5ef5..3d2c65f827f 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -70,7 +70,6 @@ def get_memory_usage(obj: Any): class TestStreamReader: - DATA: bytes = b"line1\nline2\nline3\n" def _make_one(self, *args, **kwargs) -> streams.StreamReader: diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index 1120ebbc9db..d2e189b18d6 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -30,7 +30,6 @@ async def hello(request): return web.Response(body=_hello_world_bytes) async def websocket_handler(request): - ws = web.WebSocketResponse() await ws.prepare(request) msg = await ws.receive() @@ -88,7 +87,7 @@ async def test_aiohttp_client_close_is_idempotent() -> None: class TestAioHTTPTestCase(AioHTTPTestCase): - def get_app(self): + async def get_application(self): return _create_example_app() async def test_example_with_loop(self) -> None: @@ -97,22 +96,13 @@ async def test_example_with_loop(self) -> None: text = await request.text() assert _hello_world_str == text - def test_inner_example(self) -> None: - async def test_get_route() -> None: - resp = await self.client.request("GET", "/") - assert resp.status == 200 - text = await resp.text() - assert _hello_world_str == text - - self.loop.run_until_complete(test_get_route()) - async def test_example_without_explicit_loop(self) -> None: request = await self.client.request("GET", "/") assert request.status == 200 text = await request.text() assert _hello_world_str == text - async def test_inner_example_without_explicit_loop(self) -> None: + async def test_inner_example(self) -> None: async def test_get_route() -> None: resp = await self.client.request("GET", "/") assert resp.status == 200 @@ -292,7 +282,7 @@ def test_noop(self) -> None: """ ) result = testdir.runpytest() - result.stdout.fnmatch_lines(["*RuntimeError*"]) + result.stdout.fnmatch_lines(["*TypeError*"]) async def test_server_context_manager(app: Any, loop: Any) -> None: diff --git a/tests/test_urldispatch.py b/tests/test_urldispatch.py index e302892c7c9..8055a23c5af 100644 --- a/tests/test_urldispatch.py +++ b/tests/test_urldispatch.py @@ -755,7 +755,6 @@ def test_add_route_not_started_with_slash(router: Any) -> None: def test_add_route_invalid_method(router: Any) -> None: - sample_bad_methods = { "BAD METHOD", "B@D_METHOD", diff --git a/tests/test_web_app.py b/tests/test_web_app.py index c6c4a59ad86..11cb34c06ee 100644 --- a/tests/test_web_app.py +++ b/tests/test_web_app.py @@ -1,6 +1,5 @@ -# type: ignore import asyncio -from typing import Any, Iterator +from typing import Any, AsyncIterator, Callable, Iterator, NoReturn from unittest import mock import pytest @@ -36,7 +35,7 @@ async def test_app_register_coro() -> None: app = web.Application() fut = asyncio.get_event_loop().create_future() - async def cb(app): + async def cb(app: web.Application) -> None: await asyncio.sleep(0.001) fut.set_result(123) @@ -58,7 +57,7 @@ async def test_on_shutdown() -> None: app = web.Application() called = False - async def on_shutdown(app_param): + async def on_shutdown(app_param: web.Application) -> None: nonlocal called assert app is app_param called = True @@ -76,21 +75,21 @@ async def test_on_startup() -> None: long_running2_called = False all_long_running_called = False - async def long_running1(app_param): + async def long_running1(app_param: web.Application) -> None: nonlocal long_running1_called assert app is app_param long_running1_called = True - async def long_running2(app_param): + async def long_running2(app_param: web.Application) -> None: nonlocal long_running2_called assert app is app_param long_running2_called = True - async def on_startup_all_long_running(app_param): + async def on_startup_all_long_running(app_param: web.Application) -> None: nonlocal all_long_running_called assert app is app_param all_long_running_called = True - return await asyncio.gather(long_running1(app_param), long_running2(app_param)) + await asyncio.gather(long_running1(app_param), long_running2(app_param)) app.on_startup.append(on_startup_all_long_running) app.freeze() @@ -117,8 +116,8 @@ def test_appkey_repr_concrete() -> None: "", # pytest-xdist "", ) - key = web.AppKey("key", web.Request) - assert repr(key) in ( + key2 = web.AppKey("key", web.Request) + assert repr(key2) in ( # pytest-xdist: "", "", @@ -180,14 +179,13 @@ def test_equality() -> None: def test_app_run_middlewares() -> None: - root = web.Application() sub = web.Application() root.add_subapp("/sub", sub) root.freeze() assert root._run_middlewares is False - async def middleware(request, handler: Handler): + async def middleware(request: web.Request, handler: Handler) -> web.StreamResponse: return await handler(request) root = web.Application(middlewares=[middleware]) @@ -215,22 +213,22 @@ def test_subapp_pre_frozen_after_adding() -> None: def test_app_inheritance() -> None: with pytest.raises(TypeError): - class A(web.Application): + class A(web.Application): # type: ignore[misc] pass def test_app_custom_attr() -> None: app = web.Application() with pytest.raises(AttributeError): - app.custom = None + app.custom = None # type: ignore[attr-defined] async def test_cleanup_ctx() -> None: app = web.Application() out = [] - def f(num): - async def inner(app): + def f(num: int) -> Callable[[web.Application], AsyncIterator[None]]: + async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -252,8 +250,10 @@ async def test_cleanup_ctx_exception_on_startup() -> None: exc = Exception("fail") - def f(num, fail=False): - async def inner(app): + def f( + num: int, fail: bool = False + ) -> Callable[[web.Application], AsyncIterator[None]]: + async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) if fail: raise exc @@ -280,8 +280,10 @@ async def test_cleanup_ctx_exception_on_cleanup() -> None: exc = Exception("fail") - def f(num, fail=False): - async def inner(app): + def f( + num: int, fail: bool = False + ) -> Callable[[web.Application], AsyncIterator[None]]: + async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -306,13 +308,13 @@ async def test_cleanup_ctx_cleanup_after_exception() -> None: app = web.Application() ctx_state = None - async def success_ctx(app): + async def success_ctx(app: web.Application) -> AsyncIterator[None]: nonlocal ctx_state ctx_state = "START" yield ctx_state = "CLEAN" - async def fail_ctx(app): + async def fail_ctx(app: web.Application) -> AsyncIterator[NoReturn]: raise Exception() yield @@ -332,8 +334,10 @@ async def test_cleanup_ctx_exception_on_cleanup_multiple() -> None: app = web.Application() out = [] - def f(num, fail=False): - async def inner(app): + def f( + num: int, fail: bool = False + ) -> Callable[[web.Application], AsyncIterator[None]]: + async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -361,8 +365,8 @@ async def test_cleanup_ctx_multiple_yields() -> None: app = web.Application() out = [] - def f(num): - async def inner(app): + def f(num: int) -> Callable[[web.Application], AsyncIterator[None]]: + async def inner(app: web.Application) -> AsyncIterator[None]: out.append("pre_" + str(num)) yield None out.append("post_" + str(num)) @@ -384,7 +388,7 @@ async def test_subapp_chained_config_dict_visibility(aiohttp_client: Any) -> Non key1 = web.AppKey("key1", str) key2 = web.AppKey("key2", str) - async def main_handler(request): + async def main_handler(request: web.Request) -> web.Response: assert request.config_dict[key1] == "val1" assert key2 not in request.config_dict return web.Response(status=200) @@ -393,7 +397,7 @@ async def main_handler(request): root[key1] = "val1" root.add_routes([web.get("/", main_handler)]) - async def sub_handler(request): + async def sub_handler(request: web.Request) -> web.Response: assert request.config_dict[key1] == "val1" assert request.config_dict[key2] == "val2" return web.Response(status=201) @@ -414,7 +418,7 @@ async def sub_handler(request): async def test_subapp_chained_config_dict_overriding(aiohttp_client: Any) -> None: key = web.AppKey("key", str) - async def main_handler(request): + async def main_handler(request: web.Request) -> web.Response: assert request.config_dict[key] == "val1" return web.Response(status=200) @@ -422,7 +426,7 @@ async def main_handler(request): root[key] = "val1" root.add_routes([web.get("/", main_handler)]) - async def sub_handler(request): + async def sub_handler(request: web.Request) -> web.Response: assert request.config_dict[key] == "val2" return web.Response(status=201) @@ -446,7 +450,7 @@ async def test_subapp_on_startup(aiohttp_client: Any) -> None: startup_called = False - async def on_startup(app): + async def on_startup(app: web.Application) -> None: nonlocal startup_called startup_called = True app[startup] = True @@ -456,7 +460,7 @@ async def on_startup(app): ctx_pre_called = False ctx_post_called = False - async def cleanup_ctx(app): + async def cleanup_ctx(app: web.Application) -> AsyncIterator[None]: nonlocal ctx_pre_called, ctx_post_called ctx_pre_called = True app[cleanup] = True @@ -467,7 +471,7 @@ async def cleanup_ctx(app): shutdown_called = False - async def on_shutdown(app): + async def on_shutdown(app: web.Application) -> None: nonlocal shutdown_called shutdown_called = True @@ -475,7 +479,7 @@ async def on_shutdown(app): cleanup_called = False - async def on_cleanup(app): + async def on_cleanup(app: web.Application) -> None: nonlocal cleanup_called cleanup_called = True @@ -529,9 +533,9 @@ def test_app_iter() -> None: def test_app_forbid_nonslot_attr() -> None: app = web.Application() with pytest.raises(AttributeError): - app.unknow_attr + app.unknow_attr # type: ignore[attr-defined] with pytest.raises(AttributeError): - app.unknow_attr = 1 + app.unknow_attr = 1 # type: ignore[attr-defined] def test_forbid_changing_frozen_app() -> None: diff --git a/tests/test_web_exceptions.py b/tests/test_web_exceptions.py index de3b0da4b8a..2c9e2d32d2e 100644 --- a/tests/test_web_exceptions.py +++ b/tests/test_web_exceptions.py @@ -1,13 +1,13 @@ -# type: ignore import collections import pickle from traceback import format_exception -from typing import Any +from typing import Mapping, NoReturn import pytest from yarl import URL from aiohttp import web +from aiohttp.pytest_plugin import AiohttpClient def test_all_http_exceptions_exported() -> None: @@ -23,7 +23,8 @@ def test_all_http_exceptions_exported() -> None: async def test_ctor() -> None: resp = web.HTTPOk() assert resp.text == "200: OK" - assert resp.headers == {"Content-Type": "text/plain"} + compare: Mapping[str, str] = {"Content-Type": "text/plain"} + assert resp.headers == compare assert resp.reason == "OK" assert resp.status == 200 assert bool(resp) @@ -32,7 +33,8 @@ async def test_ctor() -> None: async def test_ctor_with_headers() -> None: resp = web.HTTPOk(headers={"X-Custom": "value"}) assert resp.text == "200: OK" - assert resp.headers == {"Content-Type": "text/plain", "X-Custom": "value"} + compare: Mapping[str, str] = {"Content-Type": "text/plain", "X-Custom": "value"} + assert resp.headers == compare assert resp.reason == "OK" assert resp.status == 200 @@ -40,7 +42,8 @@ async def test_ctor_with_headers() -> None: async def test_ctor_content_type() -> None: resp = web.HTTPOk(text="text", content_type="custom") assert resp.text == "text" - assert resp.headers == {"Content-Type": "custom"} + compare: Mapping[str, str] = {"Content-Type": "custom"} + assert resp.headers == compare assert resp.reason == "OK" assert resp.status == 200 assert bool(resp) @@ -53,7 +56,8 @@ async def test_ctor_content_type_without_text() -> None: ): resp = web.HTTPResetContent(content_type="custom") assert resp.text is None - assert resp.headers == {"Content-Type": "custom"} + compare: Mapping[str, str] = {"Content-Type": "custom"} + assert resp.headers == compare assert resp.reason == "Reset Content" assert resp.status == 205 assert bool(resp) @@ -67,7 +71,8 @@ async def test_ctor_text_for_empty_body() -> None: ): resp = web.HTTPResetContent(text="text") assert resp.text == "text" - assert resp.headers == {"Content-Type": "text/plain"} + compare: Mapping[str, str] = {"Content-Type": "text/plain"} + assert resp.headers == compare assert resp.reason == "Reset Content" assert resp.status == 205 @@ -139,7 +144,8 @@ def test_ctor_all(self) -> None: content_type="custom", ) assert resp.text == "text" - assert resp.headers == {"X-Custom": "value", "Content-Type": "custom"} + compare: Mapping[str, str] = {"X-Custom": "value", "Content-Type": "custom"} + assert resp.headers == compare assert resp.reason == "Done" assert resp.status == 200 @@ -150,7 +156,7 @@ def test_pickle(self) -> None: text="text", content_type="custom", ) - resp.foo = "bar" + resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) @@ -160,8 +166,8 @@ def test_pickle(self) -> None: assert resp2.status == 200 assert resp2.foo == "bar" - async def test_app(self, aiohttp_client: Any) -> None: - async def handler(request): + async def test_app(self, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> NoReturn: raise web.HTTPOk() app = web.Application() @@ -189,7 +195,7 @@ def test_empty_location(self) -> None: with pytest.raises(ValueError): web.HTTPFound(location="") with pytest.raises(ValueError): - web.HTTPFound(location=None) + web.HTTPFound(location=None) # type: ignore[arg-type] def test_location_CRLF(self) -> None: exc = web.HTTPFound(location="/redirect\r\n") @@ -203,7 +209,7 @@ def test_pickle(self) -> None: text="text", content_type="custom", ) - resp.foo = "bar" + resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) @@ -214,8 +220,8 @@ def test_pickle(self) -> None: assert resp2.status == 302 assert resp2.foo == "bar" - async def test_app(self, aiohttp_client: Any) -> None: - async def handler(request): + async def test_app(self, aiohttp_client: AiohttpClient) -> None: + async def handler(request: web.Request) -> NoReturn: raise web.HTTPFound(location="/redirect") app = web.Application() @@ -242,11 +248,12 @@ async def test_ctor(self) -> None: assert resp.method == "GET" assert resp.allowed_methods == {"POST", "PUT"} assert resp.text == "text" - assert resp.headers == { + compare: Mapping[str, str] = { "X-Custom": "value", "Content-Type": "custom", "Allow": "POST,PUT", } + assert resp.headers == compare assert resp.reason == "Unsupported" assert resp.status == 405 @@ -259,7 +266,7 @@ def test_pickle(self) -> None: text="text", content_type="custom", ) - resp.foo = "bar" + resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) @@ -283,7 +290,8 @@ def test_ctor(self) -> None: assert resp.text == ( "Maximum request body size 100 exceeded, " "actual body size 123" ) - assert resp.headers == {"X-Custom": "value", "Content-Type": "text/plain"} + compare: Mapping[str, str] = {"X-Custom": "value", "Content-Type": "text/plain"} + assert resp.headers == compare assert resp.reason == "Too large" assert resp.status == 413 @@ -291,7 +299,7 @@ def test_pickle(self) -> None: resp = web.HTTPRequestEntityTooLarge( 100, actual_size=123, headers={"X-Custom": "value"}, reason="Too large" ) - resp.foo = "bar" + resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) @@ -313,11 +321,12 @@ def test_ctor(self) -> None: ) assert resp.link == URL("http://warning.or.kr/") assert resp.text == "text" - assert resp.headers == { + compare: Mapping[str, str] = { "X-Custom": "value", "Content-Type": "custom", "Link": '; rel="blocked-by"', } + assert resp.headers == compare assert resp.reason == "Zaprescheno" assert resp.status == 451 @@ -329,7 +338,7 @@ def test_pickle(self) -> None: text="text", content_type="custom", ) - resp.foo = "bar" + resp.foo = "bar" # type: ignore[attr-defined] for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): pickled = pickle.dumps(resp, proto) resp2 = pickle.loads(pickled) diff --git a/tests/test_web_functional.py b/tests/test_web_functional.py index cbeb015e7e8..5641d2c9096 100644 --- a/tests/test_web_functional.py +++ b/tests/test_web_functional.py @@ -191,7 +191,6 @@ async def handler(request): async def test_post_json(aiohttp_client: Any) -> None: - dct = {"key": "текст"} async def handler(request): @@ -318,7 +317,6 @@ async def handler(request): async def test_post_single_file(aiohttp_client: Any) -> None: - here = pathlib.Path(__file__).parent def check_file(fs): @@ -386,7 +384,6 @@ async def handler(request): async def test_post_files(aiohttp_client: Any) -> None: - here = pathlib.Path(__file__).parent def check_file(fs): @@ -523,7 +520,6 @@ async def handler(request): async def test_100_continue_custom(aiohttp_client: Any) -> None: - expect_received = False async def handler(request): @@ -603,7 +599,6 @@ async def expect_handler(request: web.Request) -> Optional[web.Response]: async def test_100_continue_for_not_found(aiohttp_client: Any) -> None: - app = web.Application() client = await aiohttp_client(app) @@ -697,7 +692,6 @@ async def handler(request): async def test_upload_file(aiohttp_client: Any) -> None: - here = pathlib.Path(__file__).parent fname = here / "aiohttp.png" with fname.open("rb") as f: @@ -844,7 +838,6 @@ async def handler(request): async def test_response_with_async_gen(aiohttp_client: Any, fname: Any) -> None: - with fname.open("rb") as f: data = f.read() @@ -877,7 +870,6 @@ async def handler(request): async def test_response_with_async_gen_no_params( aiohttp_client: Any, fname: Any ) -> None: - with fname.open("rb") as f: data = f.read() @@ -1109,7 +1101,6 @@ async def handler(request): async def test_start_without_routes(aiohttp_client: Any) -> None: - app = web.Application() client = await aiohttp_client(app) @@ -1895,7 +1886,6 @@ async def handler(request): async def test_iter_any(aiohttp_server: Any) -> None: - data = b"0123456789" * 1024 async def handler(request): @@ -1915,7 +1905,6 @@ async def handler(request): async def test_request_tracing(aiohttp_server: Any) -> None: - on_request_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) on_request_end = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) on_dns_resolvehost_start = mock.Mock(side_effect=make_mocked_coro(mock.Mock())) diff --git a/tests/test_web_log.py b/tests/test_web_log.py index fa5fb27f744..6a09479f4b7 100644 --- a/tests/test_web_log.py +++ b/tests/test_web_log.py @@ -1,5 +1,6 @@ # type: ignore import datetime +import logging import platform import sys from typing import Any @@ -251,3 +252,14 @@ def log(self, request, response, time): resp = await client.get("/") assert 200 == resp.status assert msg == "contextvars: uuid" + + +def test_logger_does_nothing_when_disabled(caplog: pytest.LogCaptureFixture) -> None: + """Test that the logger does nothing when the log level is disabled.""" + mock_logger = logging.getLogger("test.aiohttp.log") + mock_logger.setLevel(logging.INFO) + access_logger = AccessLogger(mock_logger, "%b") + access_logger.log( + mock.Mock(name="mock_request"), mock.Mock(name="mock_response"), 42 + ) + assert "mock_response" in caplog.text diff --git a/tests/test_web_request_handler.py b/tests/test_web_request_handler.py index 005cfaaecb7..06f99be76c0 100644 --- a/tests/test_web_request_handler.py +++ b/tests/test_web_request_handler.py @@ -1,12 +1,10 @@ -# type: ignore -from typing import Any from unittest import mock from aiohttp import web from aiohttp.test_utils import make_mocked_coro -async def serve(request: Any): +async def serve(request: web.BaseRequest) -> web.Response: return web.Response() @@ -16,8 +14,8 @@ async def test_repr() -> None: assert "" == repr(handler) - handler.transport = object() - assert "" == repr(handler) + with mock.patch.object(handler, "transport", autospec=True): + assert "" == repr(handler) async def test_connections() -> None: @@ -26,10 +24,10 @@ async def test_connections() -> None: handler = object() transport = object() - manager.connection_made(handler, transport) + manager.connection_made(handler, transport) # type: ignore[arg-type] assert manager.connections == [handler] - manager.connection_lost(handler, None) + manager.connection_lost(handler, None) # type: ignore[arg-type] assert manager.connections == [] diff --git a/tests/test_web_response.py b/tests/test_web_response.py index 44d81f112d4..5aeb2a085b5 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -603,7 +603,6 @@ async def write_headers(status_line, headers): req = make_request("GET", "/", writer=writer) payload = BytesPayload(b"answer", headers={"X-Test-Header": "test"}) resp = Response(body=payload) - assert resp.content_length == 6 resp.body = payload resp.enable_compression(ContentCoding.gzip) await resp.prepare(req) diff --git a/tests/test_web_sendfile_functional.py b/tests/test_web_sendfile_functional.py index 90ce3a0384a..6a64f8aa231 100644 --- a/tests/test_web_sendfile_functional.py +++ b/tests/test_web_sendfile_functional.py @@ -145,7 +145,6 @@ async def test_static_file_ok_string_path( async def test_static_file_not_exists(aiohttp_client: Any) -> None: - app = web.Application() client = await aiohttp_client(app) @@ -156,7 +155,6 @@ async def test_static_file_not_exists(aiohttp_client: Any) -> None: async def test_static_file_name_too_long(aiohttp_client: Any) -> None: - app = web.Application() client = await aiohttp_client(app) @@ -167,7 +165,6 @@ async def test_static_file_name_too_long(aiohttp_client: Any) -> None: async def test_static_file_upper_directory(aiohttp_client: Any) -> None: - app = web.Application() client = await aiohttp_client(app) diff --git a/tests/test_web_server.py b/tests/test_web_server.py index b97e0fa7b64..3e7eff2ad8c 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -1,5 +1,6 @@ # type: ignore import asyncio +from contextlib import suppress from typing import Any from unittest import mock @@ -207,3 +208,80 @@ async def handler(request): ) logger.exception.assert_called_with("Error handling request", exc_info=exc) + + +async def test_handler_cancellation(aiohttp_unused_port) -> None: + event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal event + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + event.set() + raise + else: + raise web.HTTPInternalServerError() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app, handler_cancellation=True) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + assert runner.server.handler_cancellation, "Flag was not propagated" + + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(event.wait(), timeout=1) + assert event.is_set(), "Request handler hasn't been cancelled" + finally: + await asyncio.gather(runner.shutdown(), site.stop()) + + +async def test_no_handler_cancellation(aiohttp_unused_port) -> None: + timeout_event = asyncio.Event() + done_event = asyncio.Event() + port = aiohttp_unused_port() + + async def on_request(_: web.Request) -> web.Response: + nonlocal done_event, timeout_event + await asyncio.wait_for(timeout_event.wait(), timeout=5) + done_event.set() + return web.Response() + + app = web.Application() + app.router.add_route("GET", "/", on_request) + + runner = web.AppRunner(app) + await runner.setup() + + site = web.TCPSite(runner, host="localhost", port=port) + + await site.start() + + try: + async with client.ClientSession( + timeout=client.ClientTimeout(total=0.1) + ) as sess: + with pytest.raises(asyncio.TimeoutError): + await sess.get(f"http://localhost:{port}/") + await asyncio.sleep(0.1) + timeout_event.set() + + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(done_event.wait(), timeout=1) + assert done_event.is_set() + finally: + await asyncio.gather(runner.shutdown(), site.stop()) diff --git a/tests/test_web_urldispatcher.py b/tests/test_web_urldispatcher.py index ce2ec3cc77f..299da051554 100644 --- a/tests/test_web_urldispatcher.py +++ b/tests/test_web_urldispatcher.py @@ -1,7 +1,6 @@ -# type: ignore import asyncio import pathlib -from typing import Any +from typing import Optional from unittest import mock from unittest.mock import MagicMock @@ -9,6 +8,7 @@ import yarl from aiohttp import web +from aiohttp.pytest_plugin import AiohttpClient from aiohttp.web_urldispatcher import SystemRoute @@ -41,12 +41,12 @@ ], ) async def test_access_root_of_static_handler( - tmp_path: Any, - aiohttp_client: Any, - show_index: Any, - status: Any, - prefix: Any, - data: Any, + tmp_path: pathlib.Path, + aiohttp_client: AiohttpClient, + show_index: bool, + status: int, + prefix: str, + data: Optional[bytes], ) -> None: # Tests the operation of static file server. # Try to access the root of static file server, and make @@ -79,7 +79,9 @@ async def test_access_root_of_static_handler( assert read_ == data -async def test_follow_symlink(tmp_path: Any, aiohttp_client: Any) -> None: +async def test_follow_symlink( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +) -> None: # Tests the access to a symlink, in static folder data = "hello world" @@ -113,7 +115,11 @@ async def test_follow_symlink(tmp_path: Any, aiohttp_client: Any) -> None: ], ) async def test_access_to_the_file_with_spaces( - tmp_path: Any, aiohttp_client: Any, dir_name: Any, filename: Any, data: Any + tmp_path: pathlib.Path, + aiohttp_client: AiohttpClient, + dir_name: str, + filename: str, + data: str, ) -> None: # Checks operation of static files with spaces @@ -138,7 +144,9 @@ async def test_access_to_the_file_with_spaces( await r.release() -async def test_access_non_existing_resource(tmp_path: Any, aiohttp_client: Any) -> None: +async def test_access_non_existing_resource( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +) -> None: # Tests accessing non-existing resource # Try to access a non-exiting resource and make sure that 404 HTTP status # returned. @@ -162,12 +170,12 @@ async def test_access_non_existing_resource(tmp_path: Any, aiohttp_client: Any) ], ) async def test_url_escaping( - aiohttp_client: Any, registered_path: Any, request_url: Any + aiohttp_client: AiohttpClient, registered_path: str, request_url: str ) -> None: # Tests accessing a resource with app = web.Application() - async def handler(request): + async def handler(request: web.Request) -> web.Response: return web.Response() app.router.add_get(registered_path, handler) @@ -182,7 +190,7 @@ async def test_handler_metadata_persistence() -> None: # router. app = web.Application() - async def async_handler(request): + async def async_handler(request: web.Request) -> web.Response: """Doc""" return web.Response() @@ -193,7 +201,9 @@ async def async_handler(request): assert route.handler.__doc__ == "Doc" -async def test_unauthorized_folder_access(tmp_path: Any, aiohttp_client: Any) -> None: +async def test_unauthorized_folder_access( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +) -> None: # Tests the unauthorized access to a folder of static file server. # Try to list a folder content of static file server when server does not # have permissions to do so for the folder. @@ -218,7 +228,9 @@ async def test_unauthorized_folder_access(tmp_path: Any, aiohttp_client: Any) -> assert r.status == 403 -async def test_access_symlink_loop(tmp_path: Any, aiohttp_client: Any) -> None: +async def test_access_symlink_loop( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +) -> None: # Tests the access to a looped symlink, which could not be resolved. my_dir_path = tmp_path / "my_symlink" pathlib.Path(str(my_dir_path)).symlink_to(str(my_dir_path), True) @@ -234,7 +246,9 @@ async def test_access_symlink_loop(tmp_path: Any, aiohttp_client: Any) -> None: assert r.status == 404 -async def test_access_special_resource(tmp_path: Any, aiohttp_client: Any) -> None: +async def test_access_special_resource( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +) -> None: # Tests the access to a resource that is neither a file nor a directory. # Checks that if a special resource is accessed (f.e. named pipe or UNIX # domain socket) then 404 HTTP status returned. @@ -261,7 +275,9 @@ async def test_access_special_resource(tmp_path: Any, aiohttp_client: Any) -> No assert r.status == 403 -async def test_static_head(tmp_path: Any, aiohttp_client: Any) -> None: +async def test_static_head( + tmp_path: pathlib.Path, aiohttp_client: AiohttpClient +) -> None: # Test HEAD on static route my_file_path = tmp_path / "test.txt" with my_file_path.open("wb") as fw: @@ -299,11 +315,11 @@ def test_system_route() -> None: assert "test" == route.reason -async def test_allow_head(aiohttp_client: Any) -> None: +async def test_allow_head(aiohttp_client: AiohttpClient) -> None: # Test allow_head on routes. app = web.Application() - async def handler(_): + async def handler(request: web.Request) -> web.Response: return web.Response() app.router.add_get("/a", handler, name="a") @@ -334,12 +350,12 @@ async def handler(_): "/{a}", ], ) -def test_reuse_last_added_resource(path: Any) -> None: +def test_reuse_last_added_resource(path: str) -> None: # Test that adding a route with the same name and path of the last added # resource doesn't create a new resource. app = web.Application() - async def handler(request): + async def handler(request: web.Request) -> web.Response: return web.Response() app.router.add_get(path, handler, name="a") @@ -351,27 +367,29 @@ async def handler(request): def test_resource_raw_match() -> None: app = web.Application() - async def handler(request): + async def handler(request: web.Request) -> web.Response: return web.Response() route = app.router.add_get("/a", handler, name="a") + assert route.resource is not None assert route.resource.raw_match("/a") route = app.router.add_get("/{b}", handler, name="b") + assert route.resource is not None assert route.resource.raw_match("/{b}") resource = app.router.add_static("/static", ".") assert not resource.raw_match("/static") -async def test_add_view(aiohttp_client: Any) -> None: +async def test_add_view(aiohttp_client: AiohttpClient) -> None: app = web.Application() class MyView(web.View): - async def get(self): + async def get(self) -> web.Response: return web.Response() - async def post(self): + async def post(self) -> web.Response: return web.Response() app.router.add_view("/a", MyView) @@ -391,15 +409,15 @@ async def post(self): await r.release() -async def test_decorate_view(aiohttp_client: Any) -> None: +async def test_decorate_view(aiohttp_client: AiohttpClient) -> None: routes = web.RouteTableDef() @routes.view("/a") class MyView(web.View): - async def get(self): + async def get(self) -> web.Response: return web.Response() - async def post(self): + async def post(self) -> web.Response: return web.Response() app = web.Application() @@ -420,14 +438,14 @@ async def post(self): await r.release() -async def test_web_view(aiohttp_client: Any) -> None: +async def test_web_view(aiohttp_client: AiohttpClient) -> None: app = web.Application() class MyView(web.View): - async def get(self): + async def get(self) -> web.Response: return web.Response() - async def post(self): + async def post(self) -> web.Response: return web.Response() app.router.add_routes([web.view("/a", MyView)]) @@ -447,7 +465,9 @@ async def post(self): await r.release() -async def test_static_absolute_url(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_static_absolute_url( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: # requested url is an absolute name like # /static/\\machine_name\c$ or /static/D:\path # where the static dir is totally different @@ -461,11 +481,13 @@ async def test_static_absolute_url(aiohttp_client: Any, tmp_path: Any) -> None: assert resp.status == 403 -async def test_for_issue_5250(aiohttp_client: Any, tmp_path: Any) -> None: +async def test_for_issue_5250( + aiohttp_client: AiohttpClient, tmp_path: pathlib.Path +) -> None: app = web.Application() app.router.add_static("/foo", tmp_path) - async def get_foobar(request): + async def get_foobar(request: web.Request) -> web.Response: return web.Response(body="success!") app.router.add_get("/foobar", get_foobar) @@ -490,14 +512,14 @@ async def get_foobar(request): ids=("urldecoded_route", "urldecoded_route_with_regex", "urlencoded_route"), ) async def test_decoded_url_match( - aiohttp_client, - route_definition, - urlencoded_path, - expected_http_resp_status, + aiohttp_client: AiohttpClient, + route_definition: str, + urlencoded_path: str, + expected_http_resp_status: int, ) -> None: app = web.Application() - async def handler(_): + async def handler(request: web.Request) -> web.Response: return web.Response() app.router.add_get(route_definition, handler) diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index 65f50991d54..093cf549cf6 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -364,7 +364,6 @@ async def test_concurrent_receive(make_request: Any) -> None: async def test_close_exc(make_request: Any) -> None: - req = make_request("GET", "/") ws = WebSocketResponse() await ws.prepare(req) diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index 799379e3602..7071a0335f8 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -131,7 +131,6 @@ async def handler(request): async def test_send_recv_text(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -165,7 +164,6 @@ async def handler(request): async def test_send_recv_bytes(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -275,7 +273,6 @@ async def handler(request): async def test_concurrent_close(loop: Any, aiohttp_client: Any) -> None: - srv_ws = None async def handler(request): @@ -313,7 +310,6 @@ async def handler(request): async def test_auto_pong_with_closing_by_peer(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -343,7 +339,6 @@ async def handler(request): async def test_ping(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -370,7 +365,6 @@ async def handler(request): async def aiohttp_client_ping(loop: Any, aiohttp_client: Any): - closed = loop.create_future() async def handler(request): @@ -396,7 +390,6 @@ async def handler(request): async def test_pong(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -431,7 +424,6 @@ async def handler(request): async def test_change_status(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -456,7 +448,6 @@ async def handler(request): async def test_handle_protocol(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -478,7 +469,6 @@ async def handler(request): async def test_server_close_handshake(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): @@ -501,7 +491,6 @@ async def handler(request): async def aiohttp_client_close_handshake(loop: Any, aiohttp_client: Any): - closed = loop.create_future() async def handler(request): @@ -674,7 +663,6 @@ async def handler(request): async with aiohttp.ClientSession() as sm: async with sm.ws_connect(server.make_url("/")) as resp: - items = ["q1", "q2", "q3"] for item in items: await resp.send_str(item) @@ -687,7 +675,6 @@ async def handler(request): async def test_closed_async_for(loop: Any, aiohttp_client: Any) -> None: - closed = loop.create_future() async def handler(request): diff --git a/tests/test_worker.py b/tests/test_worker.py index 5f973179228..34e0234a113 100644 --- a/tests/test_worker.py +++ b/tests/test_worker.py @@ -1,39 +1,37 @@ -# type: ignore # Tests for aiohttp/worker.py import asyncio import os import socket import ssl +from typing import TYPE_CHECKING, Callable, Dict, Optional from unittest import mock import pytest +from _pytest.fixtures import SubRequest from aiohttp import web -base_worker = pytest.importorskip("aiohttp.worker") +if TYPE_CHECKING: + from aiohttp import worker as base_worker +else: + base_worker = pytest.importorskip("aiohttp.worker") try: import uvloop except ImportError: - uvloop = None + uvloop = None # type: ignore[assignment] WRONG_LOG_FORMAT = '%a "%{Referrer}i" %(h)s %(l)s %s' ACCEPTABLE_LOG_FORMAT = '%a "%{Referrer}i" %s' -# tokio event loop does not allow to override attributes -def skip_if_no_dict(loop): - if not hasattr(loop, "__dict__"): - pytest.skip("can not override loop attributes") - - class BaseTestWorker: - def __init__(self): - self.servers = {} + def __init__(self) -> None: + self.servers: Dict[object, object] = {} self.exit_code = 0 - self._notify_waiter = None + self._notify_waiter: Optional[asyncio.Future[bool]] = None self.cfg = mock.Mock() self.cfg.graceful_timeout = 100 self.pid = "pid" @@ -54,14 +52,16 @@ class UvloopWorker(BaseTestWorker, base_worker.GunicornUVLoopWebWorker): @pytest.fixture(params=PARAMS) -def worker(request, loop): +def worker( + request: SubRequest, loop: asyncio.AbstractEventLoop +) -> base_worker.GunicornWebWorker: asyncio.set_event_loop(loop) ret = request.param() ret.notify = mock.Mock() - return ret + return ret # type: ignore[no-any-return] -def test_init_process(worker) -> None: +def test_init_process(worker: base_worker.GunicornWebWorker) -> None: with mock.patch("aiohttp.worker.asyncio") as m_asyncio: try: worker.init_process() @@ -72,7 +72,9 @@ def test_init_process(worker) -> None: assert m_asyncio.set_event_loop.called -def test_run(worker, loop) -> None: +def test_run( + worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop +) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT @@ -86,7 +88,9 @@ def test_run(worker, loop) -> None: assert loop.is_closed() -def test_run_async_factory(worker, loop) -> None: +def test_run_async_factory( + worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop +) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT @@ -94,8 +98,8 @@ def test_run_async_factory(worker, loop) -> None: worker.sockets = [] app = worker.wsgi - async def make_app(): - return app + async def make_app() -> web.Application: + return app # type: ignore[no-any-return] worker.wsgi = make_app @@ -107,7 +111,9 @@ async def make_app(): assert loop.is_closed() -def test_run_not_app(worker, loop) -> None: +def test_run_not_app( + worker: base_worker.GunicornWebWorker, loop: asyncio.AbstractEventLoop +) -> None: worker.log = mock.Mock() worker.cfg = mock.Mock() worker.cfg.access_log_format = ACCEPTABLE_LOG_FORMAT @@ -121,32 +127,24 @@ def test_run_not_app(worker, loop) -> None: assert loop.is_closed() -def test_handle_quit(worker, loop) -> None: - worker.loop = mock.Mock() - worker.handle_quit(object(), object()) - assert not worker.alive - assert worker.exit_code == 0 - worker.loop.call_later.asset_called_with(0.1, worker._notify_waiter_done) - - -def test_handle_abort(worker) -> None: +def test_handle_abort(worker: base_worker.GunicornWebWorker) -> None: with mock.patch("aiohttp.worker.sys") as m_sys: - worker.handle_abort(object(), object()) + worker.handle_abort(0, None) assert not worker.alive assert worker.exit_code == 1 m_sys.exit.assert_called_with(1) -def test__wait_next_notify(worker) -> None: - worker.loop = mock.Mock() - worker._notify_waiter_done = mock.Mock() - fut = worker._wait_next_notify() +def test__wait_next_notify(worker: base_worker.GunicornWebWorker) -> None: + worker.loop = mloop = mock.create_autospec(asyncio.AbstractEventLoop) + with mock.patch.object(worker, "_notify_waiter_done", autospec=True): + fut = worker._wait_next_notify() - assert worker._notify_waiter == fut - worker.loop.call_later.assert_called_with(1.0, worker._notify_waiter_done, fut) + assert worker._notify_waiter == fut + mloop.call_later.assert_called_with(1.0, worker._notify_waiter_done, fut) -def test__notify_waiter_done(worker) -> None: +def test__notify_waiter_done(worker: base_worker.GunicornWebWorker) -> None: worker._notify_waiter = None worker._notify_waiter_done() assert worker._notify_waiter is None @@ -159,7 +157,9 @@ def test__notify_waiter_done(worker) -> None: waiter.set_result.assert_called_with(True) -def test__notify_waiter_done_explicit_waiter(worker) -> None: +def test__notify_waiter_done_explicit_waiter( + worker: base_worker.GunicornWebWorker, +) -> None: worker._notify_waiter = None assert worker._notify_waiter is None @@ -173,7 +173,7 @@ def test__notify_waiter_done_explicit_waiter(worker) -> None: assert not waiter2.set_result.called -def test_init_signals(worker) -> None: +def test_init_signals(worker: base_worker.GunicornWebWorker) -> None: worker.loop = mock.Mock() worker.init_signals() assert worker.loop.add_signal_handler.called @@ -189,19 +189,23 @@ def test_init_signals(worker) -> None: ), ], ) -def test__get_valid_log_format_ok(worker, source, result) -> None: +def test__get_valid_log_format_ok( + worker: base_worker.GunicornWebWorker, source: str, result: str +) -> None: assert result == worker._get_valid_log_format(source) -def test__get_valid_log_format_exc(worker) -> None: +def test__get_valid_log_format_exc(worker: base_worker.GunicornWebWorker) -> None: with pytest.raises(ValueError) as exc: worker._get_valid_log_format(WRONG_LOG_FORMAT) assert "%(name)s" in str(exc.value) -async def test__run_ok_parent_changed(worker, loop, aiohttp_unused_port) -> None: - skip_if_no_dict(loop) - +async def test__run_ok_parent_changed( + worker: base_worker.GunicornWebWorker, + loop: asyncio.AbstractEventLoop, + aiohttp_unused_port: Callable[[], int], +) -> None: worker.ppid = 0 worker.alive = True sock = socket.socket() @@ -220,9 +224,11 @@ async def test__run_ok_parent_changed(worker, loop, aiohttp_unused_port) -> None worker.log.info.assert_called_with("Parent changed, shutting down: %s", worker) -async def test__run_exc(worker, loop, aiohttp_unused_port) -> None: - skip_if_no_dict(loop) - +async def test__run_exc( + worker: base_worker.GunicornWebWorker, + loop: asyncio.AbstractEventLoop, + aiohttp_unused_port: Callable[[], int], +) -> None: worker.ppid = os.getppid() worker.alive = True sock = socket.socket() @@ -235,9 +241,10 @@ async def test__run_exc(worker, loop, aiohttp_unused_port) -> None: worker.cfg.max_requests = 0 worker.cfg.is_ssl = False - def raiser(): + def raiser() -> None: waiter = worker._notify_waiter worker.alive = False + assert waiter is not None waiter.set_exception(RuntimeError()) loop.call_later(0.1, raiser) @@ -247,8 +254,8 @@ def raiser(): def test__create_ssl_context_without_certs_and_ciphers( - worker, - tls_certificate_pem_path, + worker: base_worker.GunicornWebWorker, + tls_certificate_pem_path: str, ) -> None: worker.cfg.ssl_version = ssl.PROTOCOL_TLS_CLIENT worker.cfg.cert_reqs = ssl.CERT_OPTIONAL @@ -261,8 +268,8 @@ def test__create_ssl_context_without_certs_and_ciphers( def test__create_ssl_context_with_ciphers( - worker, - tls_certificate_pem_path, + worker: base_worker.GunicornWebWorker, + tls_certificate_pem_path: str, ) -> None: worker.cfg.ssl_version = ssl.PROTOCOL_TLS_CLIENT worker.cfg.cert_reqs = ssl.CERT_OPTIONAL @@ -275,9 +282,9 @@ def test__create_ssl_context_with_ciphers( def test__create_ssl_context_with_ca_certs( - worker, - tls_ca_certificate_pem_path, - tls_certificate_pem_path, + worker: base_worker.GunicornWebWorker, + tls_ca_certificate_pem_path: str, + tls_certificate_pem_path: str, ) -> None: worker.cfg.ssl_version = ssl.PROTOCOL_TLS_CLIENT worker.cfg.cert_reqs = ssl.CERT_OPTIONAL diff --git a/tools/bench-asyncio-write.py b/tools/bench-asyncio-write.py index 5ae347172fb..3c35f295a58 100644 --- a/tools/bench-asyncio-write.py +++ b/tools/bench-asyncio-write.py @@ -123,10 +123,11 @@ async def bench(job_title, w, body, base=None): base = await bench(t, writes[0], c) for w in writes[1:]: await bench("", w, c, base) - with open("bench.md", "w") as f: - for line in res: - f.write("| {} |\n".format(" | ".join(line))) + return res loop = asyncio.get_event_loop() -loop.run_until_complete(main(loop)) +results = loop.run_until_complete(main(loop)) +with open("bench.md", "w") as f: + for line in results: + f.write("| {} |\n".format(" | ".join(line))) diff --git a/vendor/README.rst b/vendor/README.rst new file mode 100644 index 00000000000..1f10e20cce2 --- /dev/null +++ b/vendor/README.rst @@ -0,0 +1,23 @@ +LLHTTP +------ + +When building aiohttp from source, there is a pure Python parser used by default. +For better performance, you may want to build the higher performance C parser. + +To build this ``llhttp`` parser, first get/update the submodules (to update to a +newer release, add ``--remote`` and check the branch in ``.gitmodules``):: + + git submodule update --init --recursive + +Then build ``llhttp``:: + + cd vendor/llhttp/ + npm install + make + +Then build our parser:: + + cd - + make cythonize + +Then you can build or install it with ``python -m build`` or ``pip install -e .`` diff --git a/vendor/llhttp b/vendor/llhttp index 69d6db20085..7e18596bae8 160000 --- a/vendor/llhttp +++ b/vendor/llhttp @@ -1 +1 @@ -Subproject commit 69d6db2008508489d19267a0dcab30602b16fc5b +Subproject commit 7e18596bae8f63692ded9d3250d5d984fe90dcfb