diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index c444caecf1142..ddc3747aac373 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -494,7 +494,7 @@ def __init__( ): raise RuntimeError( f'It is not supported to have {ProtocolType.WEBSOCKET.to_string()} deployment for ' - f'Deployments with more than one shard' + f'Deployments' ) is_mac_os = platform.system() == 'Darwin' is_windows_os = platform.system() == 'Windows' diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 77dcb68b926a1..6468dd39f096b 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -165,12 +165,18 @@ def _http_fastapi_default_app(self, def call_handle(request): return self.process_single_data(request, None) - return get_fastapi_app( + app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs ) + @app.on_event('shutdown') + async def _shutdown(): + await self.close() + + return app + async def _hot_reload(self): import inspect @@ -793,12 +799,14 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: if inner_dict['input']['model'].schema() == legacy_doc_schema: inner_dict['input']['model'] = legacy_doc_schema else: - inner_dict['input']['model'] = _create_aux_model_doc_list_to_list(inner_dict['input']['model']).schema() + inner_dict['input']['model'] = _create_aux_model_doc_list_to_list( + inner_dict['input']['model']).schema() if inner_dict['output']['model'].schema() == legacy_doc_schema: inner_dict['output']['model'] = legacy_doc_schema else: - inner_dict['output']['model'] = _create_aux_model_doc_list_to_list(inner_dict['output']['model']).schema() + inner_dict['output']['model'] = _create_aux_model_doc_list_to_list( + inner_dict['output']['model']).schema() else: for endpoint_name, inner_dict in schemas.items(): inner_dict['input']['model'] = inner_dict['input']['model'].schema() diff --git a/tests/integration/docarray_v2/test_v2.py b/tests/integration/docarray_v2/test_v2.py index 49bffb5247054..ef11c5e451a1a 100644 --- a/tests/integration/docarray_v2/test_v2.py +++ b/tests/integration/docarray_v2/test_v2.py @@ -990,7 +990,6 @@ def search(self, docs: DocList[TextDocWithId], **kwargs) -> DocList[ResultTestDo def test_issue_shards_missmatch_endpoint(): - class MyDoc(BaseDoc): text: str embedding: NdArray[128] @@ -1014,3 +1013,25 @@ def foo(self, docs: DocList[MyDoc], **kwargs) -> DocList[MyDocWithMatchesAndScor with d: res = d.post(on='/', inputs=DocList[MyDoc]([MyDoc(text='hey ha', embedding=np.random.rand(128))])) assert len(res) == 1 + + +@pytest.mark.parametrize('protocol', ['grpc', 'http']) +def test_closing_executor(tmpdir, protocol): + class ClosingExec(Executor): + + def __init__(self, file_path, *args, **kwargs): + super().__init__(*args, **kwargs) + self._file_path = file_path + + def close(self) -> None: + with open(self._file_path, 'w') as f: + f.write('I closed') + + file_path = f'{str(tmpdir)}/file.txt' + d = Deployment(uses=ClosingExec, uses_with={'file_path': file_path}, protocol=protocol) + with d: + pass + + with open(file_path, 'r') as f: + r = f.read() + assert r == 'I closed'