
Commit 264db6f

bulk_inference: ensure results are returned in correct order
The `ThreadPoolExecutor.map` method does not guarantee that results are returned in input order. This PR indexes each work package and sorts the results to restore the original order. The misbehavior was only observed on PyPy builds, but the API contract of not guaranteeing in-order results applies to both CPython and PyPy.
1 parent e9403e7 commit 264db6f

File tree

1 file changed: 18 additions (+18), 12 deletions (-12)

sap/aibus/dar/client/inference_client.py

Lines changed: 18 additions & 12 deletions
@@ -2,7 +2,7 @@
 Client API for the Inference microservice.
 """
 from concurrent.futures import ThreadPoolExecutor
-from typing import List, Union
+from typing import List, Union, Tuple
 
 from requests import RequestException
 
@@ -154,12 +154,13 @@ def do_bulk_inference(
         :return: the aggregated ObjectPrediction dictionaries
         """
 
-        def predict_call(work_package):
+        def predict_call(work_package: Tuple[int, list]) -> Tuple[int, list]:
+            work_package_index, objects_list = work_package
             try:
                 response = self.create_inference_request(
-                    model_name, work_package, top_n=top_n, retry=retry
+                    model_name, objects_list, top_n=top_n, retry=retry
                 )
-                return response["predictions"]
+                return (work_package_index, response["predictions"])
             except (DARHTTPException, RequestException) as exc:
                 self.log.warning(
                     "Caught %s during bulk inference. "
@@ -174,20 +175,25 @@ def predict_call(work_package):
                         "labels": None,
                         "_sdk_error": "{}: {}".format(exc.__class__.__name__, str(exc)),
                     }
-                    for inference_object in work_package
+                    for inference_object in objects_list
                 ]
-                return prediction_error
+                return (work_package_index, prediction_error)
 
-        results = []
+        # Because Executor.map may return results out of order, we add an index
+        # to each list of objects and later restore the correct order
+        input_data_indexed = enumerate(split_list(objects, LIMIT_OBJECTS_PER_CALL))
 
+        results_buffer = []
         with ThreadPoolExecutor(max_workers=4) as pool:
-            results_iterator = pool.map(
-                predict_call, split_list(objects, LIMIT_OBJECTS_PER_CALL)
-            )
-
+            results_iterator = pool.map(predict_call, input_data_indexed)
             for predictions in results_iterator:
-                results.extend(predictions)
+                results_buffer.append(predictions)
 
+        # sort by index and remove index
+        results_buffer.sort(key=lambda x: x[0])
+        results = []
+        for result in results_buffer:
+            results.extend(result[1])
         return results
 
     def create_inference_request_with_url(
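
For illustration, below is a minimal, self-contained sketch of the index-and-sort pattern this commit introduces. The helpers split_list and fake_predict are hypothetical stand-ins for the SDK's chunking helper and the remote call made via create_inference_request; this is not the SDK's actual implementation.

# Minimal sketch of the index-and-sort pattern from this commit.
# `split_list` and `fake_predict` are illustrative stand-ins, not SDK API.
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple


def split_list(items: list, chunk_size: int) -> List[list]:
    """Split `items` into chunks of at most `chunk_size` elements."""
    return [items[i : i + chunk_size] for i in range(0, len(items), chunk_size)]


def fake_predict(chunk: list) -> list:
    """Stand-in for a remote inference call; returns one result per object."""
    return [{"objectId": obj, "labels": ["example"]} for obj in chunk]


def predict_call(work_package: Tuple[int, list]) -> Tuple[int, list]:
    # Unpack the index so it can be carried through to the result.
    index, chunk = work_package
    return (index, fake_predict(chunk))


def bulk_predict(objects: list, chunk_size: int = 50) -> list:
    # Tag each chunk with its position before handing it to the pool.
    indexed_chunks = enumerate(split_list(objects, chunk_size))

    results_buffer = []
    with ThreadPoolExecutor(max_workers=4) as pool:
        for indexed_result in pool.map(predict_call, indexed_chunks):
            results_buffer.append(indexed_result)

    # Restore the original order, then drop the index and flatten.
    results_buffer.sort(key=lambda pair: pair[0])
    results = []
    for _, predictions in results_buffer:
        results.extend(predictions)
    return results


if __name__ == "__main__":
    print(bulk_predict(list(range(7)), chunk_size=3))

Carrying an explicit index makes the aggregation independent of whatever ordering guarantees the executor's map provides, at the cost of buffering all chunk results before flattening them.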
