diff --git a/inference/text-client/text_client_utils.py b/inference/text-client/text_client_utils.py index 71073ff684..1e8c2363f9 100644 --- a/inference/text-client/text_client_utils.py +++ b/inference/text-client/text_client_utils.py @@ -9,6 +9,7 @@ class DebugClient: def __init__(self, backend_url, http_client=requests): self.backend_url = backend_url self.http_client = http_client + self.available_models = self.get_available_models() def login(self, username): auth_data = self.http_client.get(f"{self.backend_url}/auth/callback/debug", params={"code": username}).json() @@ -28,7 +29,18 @@ def create_chat(self): self.message_id = None return self.chat_id + def get_available_models(self): + response = self.http_client.get( + f"{self.backend_url}/models", + headers=self.auth_headers, + ) + response.raise_for_status() + return [model["name"] for model in response.json()] + def send_message(self, message, model_config_name): + available_models = self.get_available_models() + if model_config_name not in available_models: + raise ValueError(f"Invalid model config name: {model_config_name}") response = self.http_client.post( f"{self.backend_url}/chats/{self.chat_id}/prompter_message", json={