Skip to content

Commit

Permalink
fix: allow mock https subrequest mocks. aiohttp_client compatibility
Browse files Browse the repository at this point in the history
The `is_ssl` method was monkey-patched in a synchronous way that couldn't work when simultaneous requests needed it both patched and not-patched.  Instead of undoing the ClientRequest monkey-patching we create DirectClientRequest and have our passthrough session use it instead.

For compatibility between aresponses and aiohttp_client fixtures we use a work-around recommended by @serjshevchenko that replaces the aiohttp_client loop fixture with the event loop provided by pytest-asycnio.  The hanging was caused by different fixtures running on different event loops.
  • Loading branch information
brycedrennan committed Jan 3, 2021
1 parent 5bb3d1d commit 6e5e346
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 25 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ import aiohttp
import pytest
import aresponses


@pytest.mark.asyncio
async def test_foo(event_loop):
async with aresponses.ResponsesMockServer(loop=event_loop) as arsps:
Expand Down Expand Up @@ -213,6 +214,15 @@ async def aresponses(loop):
yield server
```

If you're trying to use the `aiohttp_client` test fixture then you'll need to mock out the aiohttp `loop` fixture
instead:
```python
@pytest.fixture
def loop(event_loop):
"""replace aiohttp loop fixture with pytest-asyncio fixture"""
return event_loop
```

## Contributing

### Dev environment setup
Expand Down
38 changes: 21 additions & 17 deletions aresponses/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import logging
from copy import copy
from functools import partial
from typing import List, NamedTuple

import pytest
Expand Down Expand Up @@ -38,7 +37,7 @@ async def _start(self, request, *_, **__):
writer = self._payload_writer = request._payload_writer
return writer

async def write_eof(self, *_, **__):
async def write_eof(self, *_, **__): # noqa
await super().write_eof(self._body)


Expand Down Expand Up @@ -183,24 +182,28 @@ async def _find_response(self, request):

async def passthrough(self, request):
"""Make non-mocked network request"""
connector = TCPConnector()
connector._resolve_host = partial(self._old_resolver_mock, connector)

new_is_ssl = ClientRequest.is_ssl
ClientRequest.is_ssl = self._old_is_ssl
try:
original_request = request.clone(scheme="https" if request.headers["AResponsesIsSSL"] else "http")
class DirectTcpConnector(TCPConnector):
def _resolve_host(slf, *args, **kwargs): # noqa
return self._old_resolver_mock(slf, *args, **kwargs)

class DirectClientRequest(ClientRequest):
def is_ssl(slf) -> bool:
return slf._aresponses_direct_is_ssl()

connector = DirectTcpConnector()

original_request = request.clone(scheme="https" if request.headers["AResponsesIsSSL"] else "http")

headers = {k: v for k, v in request.headers.items() if k != "AResponsesIsSSL"}
headers = {k: v for k, v in request.headers.items() if k != "AResponsesIsSSL"}

async with ClientSession(connector=connector) as session:
async with getattr(session, request.method.lower())(original_request.url, headers=headers, data=(await request.read())) as r:
headers = {k: v for k, v in r.headers.items() if k.lower() == "content-type"}
data = await r.read()
response = self.Response(body=data, status=r.status, headers=headers)
return response
finally:
ClientRequest.is_ssl = new_is_ssl
async with ClientSession(connector=connector, request_class=DirectClientRequest) as session:
request_method = getattr(session, request.method.lower())
async with request_method(original_request.url, headers=headers, data=(await request.read())) as r:
headers = {k: v for k, v in r.headers.items() if k.lower() == "content-type"}
data = await r.read()
response = self.Response(body=data, status=r.status, headers=headers)
return response

async def __aenter__(self) -> "ResponsesMockServer":
await self.start_server(loop=self._loop)
Expand All @@ -213,6 +216,7 @@ async def _resolver_mock(_self, host, port, traces=None):
TCPConnector._resolve_host = _resolver_mock

self._old_is_ssl = ClientRequest.is_ssl
ClientRequest._aresponses_direct_is_ssl = ClientRequest.is_ssl

def new_is_ssl(_self):
return False
Expand Down
2 changes: 1 addition & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
import re

import pytest
import aiohttp
import pytest


@pytest.mark.asyncio
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import pytest
from aiohttp import web

from aresponses import ResponsesMockServer

@pytest.fixture
def loop(event_loop):
"""replace aiohttp loop fixture with pytest-asyncio fixture"""
return event_loop


def make_app():
Expand All @@ -28,18 +32,14 @@ async def get_ip_address(protocol):
return ip


@pytest.fixture(name="aresponses")
async def aresponses_fixture(loop):
async with ResponsesMockServer(loop=loop) as server:
yield server


@pytest.mark.asyncio
async def test_app_simple_endpoint(aiohttp_client):
client = await aiohttp_client(make_app())
r = await client.get("/constant")
assert (await r.text()) == "42"


@pytest.mark.asyncio
async def test_app_simple_endpoint_with_aresponses(aiohttp_client, aresponses):
"""
when testing your own aiohttp server you must setup passthrough to it
Expand All @@ -54,6 +54,7 @@ async def test_app_simple_endpoint_with_aresponses(aiohttp_client, aresponses):
assert (await r.text()) == "42"


@pytest.mark.asyncio
@pytest.mark.parametrize("protocol", ["http", "https"])
async def test_app_with_subrequest_using_aresponses(aiohttp_client, aresponses, protocol):
"""
Expand All @@ -67,3 +68,4 @@ async def test_app_with_subrequest_using_aresponses(aiohttp_client, aresponses,
body = await r.text()
assert r.status == 200, body
assert "ip is" in (await r.text())
aresponses.assert_plan_strictly_followed()

0 comments on commit 6e5e346

Please sign in to comment.