diff --git a/ddtrace/contrib/aioredis/patch.py b/ddtrace/contrib/aioredis/patch.py index c2752835c9f..286e65c0053 100644 --- a/ddtrace/contrib/aioredis/patch.py +++ b/ddtrace/contrib/aioredis/patch.py @@ -70,8 +70,8 @@ async def traced_execute_command(func, instance, args, kwargs): return await func(*args, **kwargs) -async def traced_pipeline(func, instance, args, kwargs): - pipeline = await func(*args, **kwargs) +def traced_pipeline(func, instance, args, kwargs): + pipeline = func(*args, **kwargs) pin = Pin.get_from(instance) if pin: pin.onto(pipeline) diff --git a/releasenotes/notes/fix-aioredis-async-with-pipeline-805966300810edf8.yaml b/releasenotes/notes/fix-aioredis-async-with-pipeline-805966300810edf8.yaml new file mode 100644 index 00000000000..e09f30b640a --- /dev/null +++ b/releasenotes/notes/fix-aioredis-async-with-pipeline-805966300810edf8.yaml @@ -0,0 +1,4 @@ +--- +fixes: + - | + Fixes incompatibility of wrapped aioredis pipelines in ``async with`` statements. diff --git a/tests/contrib/aioredis/test_aioredis.py b/tests/contrib/aioredis/test_aioredis.py index 33e736c2345..b62d74f35dd 100644 --- a/tests/contrib/aioredis/test_aioredis.py +++ b/tests/contrib/aioredis/test_aioredis.py @@ -153,6 +153,35 @@ async def test_pipeline_traced(redis_client): assert response_list[3].decode() == "bar" +@pytest.mark.skipif(aioredis_version < (2, 0), reason="only supported in aioredis >= 2.0") +@pytest.mark.asyncio +@pytest.mark.snapshot +async def test_pipeline_traced_context_manager_transaction(redis_client): + """ + Regression test for: https://github.com/DataDog/dd-trace-py/issues/3106 + + https://aioredis.readthedocs.io/en/latest/migration/#pipelines-and-transactions-multiexec + + Example:: + + async def main(): + redis = await aioredis.from_url("redis://localhost") + async with redis.pipeline(transaction=True) as pipe: + ok1, ok2 = await (pipe.set("key1", "value1").set("key2", "value2").execute()) + assert ok1 + assert ok2 + """ + + async with redis_client.pipeline(transaction=True) as p: + set_1, set_2, get_1, get_2 = await (p.set("blah", "boo").set("foo", "bar").get("blah").get("foo").execute()) + + # response from redis.set is OK if successfully pushed + assert set_1 is True + assert set_2 is True + assert get_1.decode() == "boo" + assert get_2.decode() == "bar" + + @pytest.mark.asyncio @pytest.mark.snapshot(variants={"": aioredis_version >= (2, 0), "13": aioredis_version < (2, 0)}) async def test_two_traced_pipelines(redis_client): diff --git a/tests/snapshots/tests.contrib.aioredis.test_aioredis.test_pipeline_traced_context_manager_transaction.json b/tests/snapshots/tests.contrib.aioredis.test_aioredis.test_pipeline_traced_context_manager_transaction.json new file mode 100644 index 00000000000..b969f3520fb --- /dev/null +++ b/tests/snapshots/tests.contrib.aioredis.test_aioredis.test_pipeline_traced_context_manager_transaction.json @@ -0,0 +1,28 @@ +[[ + { + "name": "redis.command", + "service": "redis", + "resource": "SET blah boo\nSET foo bar\nGET blah\nGET foo", + "trace_id": 0, + "span_id": 1, + "parent_id": 0, + "type": "redis", + "meta": { + "out.host": "127.0.0.1", + "redis.raw_command": "SET blah boo\nSET foo bar\nGET blah\nGET foo", + "runtime-id": "b734eb991b1f45f2b063db6d3c5623b9" + }, + "metrics": { + "_dd.agent_psr": 1.0, + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "out.port": 6379, + "out.redis_db": 0, + "redis.pipeline_length": 4, + "system.pid": 28312 + }, + "duration": 2132000, + "start": 1641496497488785000 + }]]