33# mypy cannot deal with some of the monkey-patching we do below.
44# https://github.com/python/mypy/issues/2427
55from typing import Optional
6- from unittest .mock import call
6+ from unittest .mock import call , Mock
77
88import pytest
99from requests import RequestException , Timeout
@@ -267,20 +267,41 @@ def test_bulk_inference_error(self, inference_client: InferenceClient):
267267
268268 exception_404 = DARHTTPException .create_from_response (url , response_404 )
269269
270+ # The old trick to return different values in a Mock based on the call order
271+ # does not work here because the code is concurrent. Instead, we use a different
272+ # objectId for those objects where we want the request to fail
273+ def make_mock_post (exc ):
274+ def post_to_endpoint (* args , ** kwargs ):
275+ payload = kwargs .pop ("payload" )
276+ object_id = payload ["objects" ][0 ]["objectId" ]
277+ if object_id == "expected-to-fail" :
278+ raise exc
279+ elif object_id == "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9" :
280+ response = Mock ()
281+ response .json .return_value = self .inference_response (
282+ len (payload ["objects" ])
283+ )
284+ return response
285+ else :
286+ raise ValueError (f"objectId '{ object_id } ' not handled in test." )
287+
288+ return post_to_endpoint
289+
290+ # Try different exceptions
270291 exceptions = [
271292 exception_404 ,
272293 RequestException ("Request Error" ),
273294 Timeout ("Timeout" ),
274295 ]
275- # Try different exceptions
276296 for exc in exceptions :
277- inference_client .session .post_to_endpoint .return_value .json .side_effect = [
278- self .inference_response (50 ),
279- exc ,
280- self .inference_response (40 ),
281- ]
297+ inference_client .session .post_to_endpoint .side_effect = make_mock_post (exc )
282298
283- many_objects = [self .objects ()[0 ] for _ in range (50 + 50 + 40 )]
299+ many_objects = []
300+ many_objects .extend ([self .objects ()[0 ] for _ in range (50 )])
301+ many_objects .extend (
302+ [self .objects (object_id = "expected-to-fail" )[0 ] for _ in range (50 )]
303+ )
304+ many_objects .extend ([self .objects ()[0 ] for _ in range (40 )])
284305 assert len (many_objects ) == 50 + 50 + 40
285306
286307 response = inference_client .do_bulk_inference (
@@ -290,7 +311,7 @@ def test_bulk_inference_error(self, inference_client: InferenceClient):
290311 )
291312
292313 expected_error_response = {
293- "objectId" : "b5cbcb34-7ab9-4da5-b7ec-654c90757eb9 " ,
314+ "objectId" : "expected-to-fail " ,
294315 "labels" : None ,
295316 # If this test fails, I found it can make pytest/PyCharm hang because it
296317 # takes too much time in difflib.
@@ -302,6 +323,7 @@ def test_bulk_inference_error(self, inference_client: InferenceClient):
302323 expected_response .extend (expected_error_response for _ in range (50 ))
303324 expected_response .extend (self .inference_response (40 )["predictions" ])
304325
326+ assert len (response ) == len (expected_response )
305327 assert response == expected_response
306328
307329 def test_bulk_inference_error_no_object_ids (
0 commit comments