diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index ed09f9c56..a53c46885 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -621,20 +621,18 @@ async def log_request_details( ) -> Response: """Middleware to log all request's host, method, path, status and request and body""" + log = LOGGER.debug if request.url.path == "/healthz" else LOGGER.info request_body = await request.body() client = request.client or Address("Unknown", -1) log_message = f"{client.host}:{client.port} {request.method} {request.url.path}" extra = { "request_body": request_body, } - LOGGER.debug(log_message, extra=extra) + log(log_message, extra=extra) response = await call_next(request) log_message += f" {response.status_code}" - if request.url.path == "/healthz": - LOGGER.debug(log_message, extra=extra) - else: - LOGGER.info(log_message, extra=extra) + log(log_message, extra=extra) return response diff --git a/tests/unit_tests/service/test_main.py b/tests/unit_tests/service/test_main.py index 4a4bcca63..a7e04105c 100644 --- a/tests/unit_tests/service/test_main.py +++ b/tests/unit_tests/service/test_main.py @@ -1,5 +1,5 @@ from unittest import mock -from unittest.mock import Mock +from unittest.mock import Mock, call import pytest from fastapi import FastAPI, Request @@ -29,31 +29,36 @@ async def root(): assert response.headers["X-BlueAPI-VERSION"] == __version__ -async def test_log_request_details(): +@pytest.mark.parametrize("path,level", [("/", "info"), ("/healthz", "debug")]) +async def test_log_request_details(path: str, level: str): with mock.patch("blueapi.service.main.LOGGER") as logger: app = FastAPI() app.middleware("http")(log_request_details) - @app.post("/") + @app.post(path) async def root(): return {"message": "Hello World"} client = TestClient(app) - response = client.post("/", content="foo") + response = client.post(path, content="foo") assert response.status_code == 200 - logger.debug.assert_called_once_with( - "testclient:50000 POST /", - extra={ - "request_body": b"foo", - }, - ) - - logger.info.assert_called_once_with( - "testclient:50000 POST / 200", - extra={ - "request_body": b"foo", - }, + log_level = getattr(logger, level) + log_level.assert_has_calls( + [ + call( + f"testclient:50000 POST {path}", + extra={ + "request_body": b"foo", + }, + ), + call( + f"testclient:50000 POST {path} 200", + extra={ + "request_body": b"foo", + }, + ), + ] )