From 51ecd411b3ce163016896f2a76412af9209e8986 Mon Sep 17 00:00:00 2001 From: Mohamad Zamini <32536264+mzamini92@users.noreply.github.com> Date: Tue, 18 Apr 2023 16:39:14 -0600 Subject: [PATCH 1/2] Update text_client_utils.py This implementation adds a new get_available_models() method to the DebugClient class, which retrieves the list of available model configurations from the API and returns a list of their names. The send_message() method then calls this method and checks if the provided model_config_name is in the list of available models. If it's not, a ValueError is raised with an appropriate error message. --- inference/text-client/text_client_utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/inference/text-client/text_client_utils.py b/inference/text-client/text_client_utils.py index 71073ff684..0d45ac1901 100644 --- a/inference/text-client/text_client_utils.py +++ b/inference/text-client/text_client_utils.py @@ -28,7 +28,18 @@ def create_chat(self): self.message_id = None return self.chat_id + def get_available_models(self): + response = self.http_client.get( + f"{self.backend_url}/models", + headers=self.auth_headers, + ) + response.raise_for_status() + return [model["name"] for model in response.json()] + def send_message(self, message, model_config_name): + available_models = self.get_available_models() + if model_config_name not in available_models: + raise ValueError(f"Invalid model config name: {model_config_name}") response = self.http_client.post( f"{self.backend_url}/chats/{self.chat_id}/prompter_message", json={ From 373d7ffaaf5c253081d0ea8071bf9ffd5d143f77 Mon Sep 17 00:00:00 2001 From: Mohamad Zamini <32536264+mzamini92@users.noreply.github.com> Date: Tue, 18 Apr 2023 23:21:10 -0600 Subject: [PATCH 2/2] Update text_client_utils.py we can modify the `__init__` method to include the line `self.available_models = self.get_available_models()` so that the list of available models is retrieved only once during the instantiation 
of the DebugClient object. --- inference/text-client/text_client_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/inference/text-client/text_client_utils.py b/inference/text-client/text_client_utils.py index 0d45ac1901..1e8c2363f9 100644 --- a/inference/text-client/text_client_utils.py +++ b/inference/text-client/text_client_utils.py @@ -9,6 +9,7 @@ class DebugClient: def __init__(self, backend_url, http_client=requests): self.backend_url = backend_url self.http_client = http_client + self.available_models = self.get_available_models() def login(self, username): auth_data = self.http_client.get(f"{self.backend_url}/auth/callback/debug", params={"code": username}).json()