diff --git a/service_streamer/service_streamer.py b/service_streamer/service_streamer.py index 7a7fd91..8096122 100644 --- a/service_streamer/service_streamer.py +++ b/service_streamer/service_streamer.py @@ -254,7 +254,7 @@ def _send_response(self, client_id, task_id, request_id, model_output): class Streamer(_BaseStreamer): def __init__(self, predict_function_or_model, batch_size, max_latency=0.1, worker_num=1, - cuda_devices=None, model_init_args=None, model_init_kwargs=None): + cuda_devices=None, model_init_args=None, model_init_kwargs=None, wait_for_worker_ready=False): super().__init__() self.worker_num = worker_num self.cuda_devices = cuda_devices @@ -267,6 +267,8 @@ def __init__(self, predict_function_or_model, batch_size, max_latency=0.1, worke self._worker_ready_events = [] self._worker_destroy_events = [] self._setup_gpu_worker() + if wait_for_worker_ready: + self._wait_for_worker_ready() self._delay_setup() def _setup_gpu_worker(self):