diff --git a/starlette_opentracing/middleware.py b/starlette_opentracing/middleware.py index cb6b85d..8ba5e61 100644 --- a/starlette_opentracing/middleware.py +++ b/starlette_opentracing/middleware.py @@ -4,13 +4,29 @@ from opentracing import InvalidCarrierException, SpanContextCorruptedException from opentracing.ext import tags from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request class StarletteTracingMiddleWare(BaseHTTPMiddleware): - def __init__(self, app, tracer): + def __init__(self, app, tracer, use_template: bool = False): # Todo: add choice between global tracer and tracer that is already configured super().__init__(app) self._tracer = tracer + self.use_template = use_template + + def get_template(self, request: Request) -> str: + """Get the template for the route endpoint.""" + method = request.method + urls = [ + route + for route in request.scope["router"].routes + if hasattr(route, "endpoint") and + "endpoint" in request.scope and + route.endpoint == request.scope["endpoint"] + ] + template = urls[0].path if len(urls) > 0 else "UNKNOWN" + method_path = method + " " + template + return method_path async def dispatch(self, request, call_next): span_ctx = None @@ -38,4 +54,8 @@ async def dispatch(self, request, call_next): span.set_tag(tags.HTTP_URL, url) response = await call_next(request) + + if self.use_template: + operation_name = self.get_template(request) + span.set_operation_name(operation_name) return response diff --git a/tests/test_tracer.py b/tests/test_tracer.py index c2025bf..5672ebd 100644 --- a/tests/test_tracer.py +++ b/tests/test_tracer.py @@ -55,3 +55,21 @@ def foo(request): # Todo: more asserts; still not sure if we should have 3 finished spans in the external tracer spans = external_tracer.finished_spans() assert len(spans) == 1 + + +def test_tracer_uses_path_templates_for_operation_names(): + app = Starlette() + mocked_tracer = MockTracer(scope_manager=ContextVarsScopeManager()) + app.add_middleware( + StarletteTracingMiddleWare, tracer=mocked_tracer, use_template=True + ) + + @app.route("/foo/{foo_id}") + def foo(foo_id: str): + return PlainTextResponse(f"Foo: {foo_id}") + + client = TestClient(app) + client.get("/foo/MyFoo") + spans = mocked_tracer.finished_spans() + assert len(spans) == 1 + assert spans[0].operation_name == "GET /foo/{foo_id}"