From 274e65d04879eeb0f33f1e6c900d61664e36445f Mon Sep 17 00:00:00 2001
From: Yannic Kilcher
Date: Sat, 27 May 2023 20:39:55 +0100
Subject: [PATCH] Added bearer token header to worker http client (for HF API) (#3240)

---
 inference/worker/__main__.py |  1 +
 inference/worker/settings.py |  3 +++
 inference/worker/utils.py    | 13 ++++++++++++-
 3 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py
index 36afc09133..569e340276 100644
--- a/inference/worker/__main__.py
+++ b/inference/worker/__main__.py
@@ -40,6 +40,7 @@ def main():
         base_url=settings.inference_server_url,
         basic_auth_username=settings.basic_auth_username,
         basic_auth_password=settings.basic_auth_password,
+        bearer_token=settings.bearer_token,
     )
 
     while True:
diff --git a/inference/worker/settings.py b/inference/worker/settings.py
index 724303ad2f..b499a401fa 100644
--- a/inference/worker/settings.py
+++ b/inference/worker/settings.py
@@ -5,6 +5,7 @@ class Settings(pydantic.BaseSettings):
     backend_url: str = "ws://localhost:8000"
     model_config_name: str = "distilgpt2"
     inference_server_url: str = "http://localhost:8001"
+    inference_server_route: str = "/generate_stream"
     safety_server_url: str = "http://localhost:8002"
 
     api_key: str = "0000"
@@ -21,6 +22,8 @@ class Settings(pydantic.BaseSettings):
     # for hf basic server
     quantize: bool = False
 
+    bearer_token: str | None = None
+
     basic_auth_username: str | None = None
     basic_auth_password: str | None = None
 
diff --git a/inference/worker/utils.py b/inference/worker/utils.py
index e7d0b9b578..c3528e4f1a 100644
--- a/inference/worker/utils.py
+++ b/inference/worker/utils.py
@@ -172,6 +172,7 @@ class HttpClient(pydantic.BaseModel):
     base_url: str
     basic_auth_username: str | None = None
     basic_auth_password: str | None = None
+    bearer_token: str | None = None
 
     @property
     def auth(self):
@@ -180,10 +181,19 @@ def auth(self):
         else:
             return None
 
+    def _maybe_add_bearer_token(self, headers: dict[str, str] | None):
+        if self.bearer_token:
+            if headers is None:
+                headers = {}
+            headers["Authorization"] = f"Bearer {self.bearer_token}"
+        return headers
+
     def get(self, path: str, **kwargs):
+        kwargs["headers"] = self._maybe_add_bearer_token(kwargs.get("headers"))
         return requests.get(self.base_url + path, auth=self.auth, **kwargs)
 
     def post(self, path: str, **kwargs):
+        kwargs["headers"] = self._maybe_add_bearer_token(kwargs.get("headers"))
         return requests.post(self.base_url + path, auth=self.auth, **kwargs)
 
 
@@ -192,9 +202,10 @@ def get_inference_server_stream_events(request: interface.GenerateStreamRequest)
         base_url=settings.inference_server_url,
         basic_auth_username=settings.basic_auth_username,
         basic_auth_password=settings.basic_auth_password,
+        bearer_token=settings.bearer_token,
     )
     response = http.post(
-        "/generate_stream",
+        settings.inference_server_route,
         json=request.dict(),
         stream=True,
         headers={"Accept": "text/event-stream"},