Skip to content

Commit accc46b

Browse files
committed
bulk_inference: fix tests on PyPy
The bulk_inference code is now multithreaded. For this reason, the trick of returning different values in the Mock based on the order of the calls no longer works. It worked somewhat accidentally on CPython, but not on PyPy.
1 parent e9403e7 commit accc46b

File tree

1 file changed

+31
-9
lines changed

1 file changed

+31
-9
lines changed

tests/sap/aibus/dar/client/test_inference_client.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# mypy cannot deal with some of the monkey-patching we do below.
44
# https://github.com/python/mypy/issues/2427
55
from typing import Optional
6-
from unittest.mock import call
6+
from unittest.mock import call, Mock
77

88
import pytest
99
from requests import RequestException, Timeout
@@ -267,20 +267,41 @@ def test_bulk_inference_error(self, inference_client: InferenceClient):
267267

268268
exception_404 = DARHTTPException.create_from_response(url, response_404)
269269

270+
# The old trick of returning different values in a Mock based on the call order
271+
# does not work here because the code is concurrent. Instead, we use a different
272+
# objectId for those objects where we want the request to fail.
273+
def make_mock_post(exc):
274+
def post_to_endpoint(*args, **kwargs):
275+
payload = kwargs.pop("payload")
276+
object_id = payload["objects"][0]["objectId"]
277+
if object_id == "expected-to-fail":
278+
raise exc
279+
elif object_id == "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9":
280+
response = Mock()
281+
response.json.return_value = self.inference_response(
282+
len(payload["objects"])
283+
)
284+
return response
285+
else:
286+
raise ValueError(f"objectId '{object_id}' not handled in test.")
287+
288+
return post_to_endpoint
289+
290+
# Try different exceptions
270291
exceptions = [
271292
exception_404,
272293
RequestException("Request Error"),
273294
Timeout("Timeout"),
274295
]
275-
# Try different exceptions
276296
for exc in exceptions:
277-
inference_client.session.post_to_endpoint.return_value.json.side_effect = [
278-
self.inference_response(50),
279-
exc,
280-
self.inference_response(40),
281-
]
297+
inference_client.session.post_to_endpoint.side_effect = make_mock_post(exc)
282298

283-
many_objects = [self.objects()[0] for _ in range(50 + 50 + 40)]
299+
many_objects = []
300+
many_objects.extend([self.objects()[0] for _ in range(50)])
301+
many_objects.extend(
302+
[self.objects(object_id="expected-to-fail")[0] for _ in range(50)]
303+
)
304+
many_objects.extend([self.objects()[0] for _ in range(40)])
284305
assert len(many_objects) == 50 + 50 + 40
285306

286307
response = inference_client.do_bulk_inference(
@@ -290,7 +311,7 @@ def test_bulk_inference_error(self, inference_client: InferenceClient):
290311
)
291312

292313
expected_error_response = {
293-
"objectId": "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9",
314+
"objectId": "expected-to-fail",
294315
"labels": None,
295316
# If this test fails, I found it can make pytest/PyCharm hang because it
296317
# takes too much time in difflib.
@@ -302,6 +323,7 @@ def test_bulk_inference_error(self, inference_client: InferenceClient):
302323
expected_response.extend(expected_error_response for _ in range(50))
303324
expected_response.extend(self.inference_response(40)["predictions"])
304325

326+
assert len(response) == len(expected_response)
305327
assert response == expected_response
306328

307329
def test_bulk_inference_error_no_object_ids(

0 commit comments

Comments
 (0)