Skip to content

Commit

Permalink
Merge pull request #57 from CircleUp/aiohttp-compat
Browse files Browse the repository at this point in the history
fix: https passthrough. aiohttp_client compatability
  • Loading branch information
brycedrennan committed Jan 5, 2021
2 parents bc32af0 + b1c6031 commit 21799c9
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 22 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ import aiohttp
import pytest
import aresponses


@pytest.mark.asyncio
async def test_foo(event_loop):
async with aresponses.ResponsesMockServer(loop=event_loop) as arsps:
Expand Down Expand Up @@ -213,6 +214,15 @@ async def aresponses(loop):
yield server
```

If you're trying to use the `aiohttp_client` test fixture then you'll need to mock out the aiohttp `loop` fixture
instead:
```python
@pytest.fixture
def loop(event_loop):
"""replace aiohttp loop fixture with pytest-asyncio fixture"""
return event_loop
```

## Contributing

### Dev environment setup
Expand Down
53 changes: 32 additions & 21 deletions aresponses/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import logging
import math
import re
from copy import copy
from functools import partial
from typing import List, NamedTuple

import pytest
Expand Down Expand Up @@ -38,7 +39,7 @@ async def _start(self, request, *_, **__):
writer = self._payload_writer = request._payload_writer
return writer

async def write_eof(self, *_, **__):
async def write_eof(self, *_, **__): # noqa
await super().write_eof(self._body)


Expand Down Expand Up @@ -92,6 +93,8 @@ class ResponsesMockServer(BaseTestServer):
ANY = ANY
Response = web.Response
RawResponse = RawResponse
INFINITY = math.inf
LOCALHOST = re.compile(r"127\.0\.0\.1:?\d{0,5}")

def __init__(self, *, scheme=sentinel, host="127.0.0.1", **kwargs):
self._responses = []
Expand Down Expand Up @@ -149,6 +152,9 @@ def add(self, host_pattern=ANY, path_pattern=ANY, method_pattern=ANY, response="

self._responses.append((route, response))

def add_local_passthrough(self, repeat=INFINITY):
self.add(host_pattern=self.LOCALHOST, repeat=repeat, response=self.passthrough)

async def _find_response(self, request):
for i, (route, response) in enumerate(self._responses):

Expand Down Expand Up @@ -183,24 +189,28 @@ async def _find_response(self, request):

async def passthrough(self, request):
"""Make non-mocked network request"""
connector = TCPConnector()
connector._resolve_host = partial(self._old_resolver_mock, connector)

new_is_ssl = ClientRequest.is_ssl
ClientRequest.is_ssl = self._old_is_ssl
try:
original_request = request.clone(scheme="https" if request.headers["AResponsesIsSSL"] else "http")
class DirectTcpConnector(TCPConnector):
def _resolve_host(slf, *args, **kwargs): # noqa
return self._old_resolver_mock(slf, *args, **kwargs)

class DirectClientRequest(ClientRequest):
def is_ssl(slf) -> bool:
return slf._aresponses_direct_is_ssl()

connector = DirectTcpConnector()

original_request = request.clone(scheme="https" if request.headers["AResponsesIsSSL"] else "http")

headers = {k: v for k, v in request.headers.items() if k != "AResponsesIsSSL"}
headers = {k: v for k, v in request.headers.items() if k != "AResponsesIsSSL"}

async with ClientSession(connector=connector) as session:
async with getattr(session, request.method.lower())(original_request.url, headers=headers, data=(await request.read())) as r:
headers = {k: v for k, v in r.headers.items() if k.lower() == "content-type"}
data = await r.read()
response = self.Response(body=data, status=r.status, headers=headers)
return response
finally:
ClientRequest.is_ssl = new_is_ssl
async with ClientSession(connector=connector, request_class=DirectClientRequest) as session:
request_method = getattr(session, request.method.lower())
async with request_method(original_request.url, headers=headers, data=(await request.read())) as r:
headers = {k: v for k, v in r.headers.items() if k.lower() == "content-type"}
data = await r.read()
response = self.Response(body=data, status=r.status, headers=headers)
return response

async def __aenter__(self) -> "ResponsesMockServer":
await self.start_server(loop=self._loop)
Expand All @@ -213,6 +223,7 @@ async def _resolver_mock(_self, host, port, traces=None):
TCPConnector._resolve_host = _resolver_mock

self._old_is_ssl = ClientRequest.is_ssl
ClientRequest._aresponses_direct_is_ssl = ClientRequest.is_ssl

def new_is_ssl(_self):
return False
Expand All @@ -238,10 +249,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):

await self.close()

def assert_no_unused_routes(self):
if self._responses:
route, _ = self._responses[0]
raise UnusedRouteError(f"Unused Route: {route}")
def assert_no_unused_routes(self, ignore_infinite_repeats=False):
for route, _ in self._responses:
if not ignore_infinite_repeats or route.repeat != self.INFINITY:
raise UnusedRouteError(f"Unused Route: {route}")

def assert_called_in_order(self):
if self._first_unordered_route is not None:
Expand Down
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from aresponses import aresponses

assert aresponses

pytest_plugins = "aiohttp.pytest_plugin"
2 changes: 1 addition & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
import re

import pytest
import aiohttp
import pytest


@pytest.mark.asyncio
Expand Down
71 changes: 71 additions & 0 deletions tests/test_with_aiohttp_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import aiohttp
import pytest
from aiohttp import web


@pytest.fixture
def loop(event_loop):
"""replace aiohttp loop fixture with pytest-asyncio fixture"""
return event_loop


def make_app():
app = web.Application()

async def constant_handler(request):
return web.Response(text="42")

async def ip_handler(request):
protocol = request.query["protocol"]
ip = await get_ip_address(protocol=protocol)
return web.Response(text=f"ip is {ip}")

app.add_routes([web.get("/constant", constant_handler)])
app.add_routes([web.get("/ip", ip_handler)])
return app


async def get_ip_address(protocol):
async with aiohttp.ClientSession() as s:
async with s.get(f"{protocol}://httpbin.org/ip") as resp:
ip = (await resp.json())["origin"]
return ip


@pytest.mark.asyncio
async def test_app_simple_endpoint(aiohttp_client):
client = await aiohttp_client(make_app())
r = await client.get("/constant")
assert (await r.text()) == "42"


@pytest.mark.asyncio
async def test_app_simple_endpoint_with_aresponses(aiohttp_client, aresponses):
"""
when testing your own aiohttp server you must setup passthrough to it
Ideally this wouldn't be necessary but haven't figured that out yet.
Perhaps all local calls should be passthrough.
"""
aresponses.add("127.0.0.1:4241", response=aresponses.passthrough)

client = await aiohttp_client(make_app(), server_kwargs={"port": 4241})
r = await client.get("/constant")
assert (await r.text()) == "42"


@pytest.mark.asyncio
@pytest.mark.parametrize("protocol", ["http", "https"])
async def test_app_with_subrequest_using_aresponses(aiohttp_client, aresponses, protocol):
"""
but passthrough doesn't work if the handler itself makes an aiohttp https request
"""
aresponses.add_local_passthrough(repeat=1)
aresponses.add("httpbin.org", response={"origin": "1.2.3.4"})

client = await aiohttp_client(make_app())
r = await client.get(f"/ip?protocol={protocol}")
body = await r.text()
assert r.status == 200, body
assert "ip is" in (await r.text())
aresponses.assert_plan_strictly_followed()

0 comments on commit 21799c9

Please sign in to comment.