From 39f37f4508e9328a51bc63dae9f99fbdce7e12be Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Thu, 6 Jun 2024 12:19:25 +0300 Subject: [PATCH 1/2] fix: Support base urls --- ai21/ai21_http_client.py | 19 ++++++++++------ ai21/clients/studio/ai21_client.py | 4 +++- .../studio/resources/chat/chat_completions.py | 3 +-- .../clients/studio/resources/studio_answer.py | 4 +--- ai21/clients/studio/resources/studio_chat.py | 4 ++-- .../studio/resources/studio_completion.py | 8 +++---- .../studio/resources/studio_custom_model.py | 9 +++----- .../studio/resources/studio_dataset.py | 10 +++------ ai21/clients/studio/resources/studio_embed.py | 3 +-- ai21/clients/studio/resources/studio_gec.py | 3 +-- .../studio/resources/studio_improvements.py | 3 +-- .../studio/resources/studio_library.py | 22 ++++++------------- .../studio/resources/studio_paraphrase.py | 3 +-- .../studio/resources/studio_resource.py | 22 ++++++++++--------- .../studio/resources/studio_segmentation.py | 3 +-- .../studio/resources/studio_summarize.py | 3 +-- .../resources/studio_summarize_by_segment.py | 3 +-- ai21/http_client.py | 9 +++++--- ai21/services/sagemaker.py | 10 ++++----- 19 files changed, 66 insertions(+), 79 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 1a3b334f..87e786d2 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -14,7 +14,7 @@ def __init__( *, api_key: Optional[str] = None, requires_api_key: bool = True, - api_host: Optional[str] = None, + base_url: Optional[str] = None, api_version: Optional[str] = None, headers: Optional[Dict[str, Any]] = None, timeout_sec: Optional[int] = None, @@ -27,7 +27,7 @@ def __init__( if requires_api_key and not self._api_key: raise MissingApiKeyError() - self._api_host = api_host + self._base_url = base_url self._api_version = api_version self._headers = headers self._timeout_sec = timeout_sec @@ -76,12 +76,17 @@ def _build_user_agent(self) -> str: def execute_http_request( self, method: str, - url: str, + path: str, params: Optional[Dict] = None, + body: Optional[Dict] = None, stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, ) -> httpx.Response: - return self._http_client.execute_http_request(method=method, url=url, params=params, files=files, stream=stream) - - def get_base_url(self) -> str: - return f"{self._api_host}/studio/{self._api_version}" + return self._http_client.execute_http_request( + method=method, + url=f"{self._base_url}{path}", + params=params, + files=files, + stream=stream, + body=body, + ) diff --git a/ai21/clients/studio/ai21_client.py b/ai21/clients/studio/ai21_client.py index 3d29575a..96ff722a 100644 --- a/ai21/clients/studio/ai21_client.py +++ b/ai21/clients/studio/ai21_client.py @@ -43,9 +43,11 @@ def __init__( env_config: _AI21EnvConfig = AI21EnvConfig, **kwargs, ): + base_url = api_host or env_config.api_host + self._http_client = AI21HTTPClient( api_key=api_key or env_config.api_key, - api_host=api_host or env_config.api_host, + base_url=f"{base_url}/studio/v1", api_version=env_config.api_version, headers=headers, timeout_sec=timeout_sec or env_config.timeout_sec, diff --git a/ai21/clients/studio/resources/chat/chat_completions.py b/ai21/clients/studio/resources/chat/chat_completions.py index cddc8cb3..f03062e8 100644 --- a/ai21/clients/studio/resources/chat/chat_completions.py +++ b/ai21/clients/studio/resources/chat/chat_completions.py @@ -75,9 +75,8 @@ def create( **kwargs, ) - url = f"{self._client.get_base_url()}/{self._module_name}" return self._post( - url=url, + path=f"/{self._module_name}", body=body, stream=stream or False, stream_cls=Stream[ChatCompletionChunk], diff --git a/ai21/clients/studio/resources/studio_answer.py b/ai21/clients/studio/resources/studio_answer.py index 0da4ae01..0dda93fe 100644 --- a/ai21/clients/studio/resources/studio_answer.py +++ b/ai21/clients/studio/resources/studio_answer.py @@ -10,8 +10,6 @@ def create( question: str, **kwargs, ) -> AnswerResponse: - url = f"{self._client.get_base_url()}/{self._module_name}" - body = self._create_body(context=context, question=question, **kwargs) - return self._post(url=url, body=body, response_cls=AnswerResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=AnswerResponse) diff --git a/ai21/clients/studio/resources/studio_chat.py b/ai21/clients/studio/resources/studio_chat.py index daccea1d..de8fa45c 100644 --- a/ai21/clients/studio/resources/studio_chat.py +++ b/ai21/clients/studio/resources/studio_chat.py @@ -48,8 +48,8 @@ def create( count_penalty=count_penalty, **kwargs, ) - url = f"{self._client.get_base_url()}/{model}/{self._module_name}" - return self._post(url=url, body=body, response_cls=ChatResponse) + + return self._post(path=f"/{model}/{self._module_name}", body=body, response_cls=ChatResponse) @property def completions(self) -> ChatCompletions: diff --git a/ai21/clients/studio/resources/studio_completion.py b/ai21/clients/studio/resources/studio_completion.py index d2b9cdc7..faaf8879 100644 --- a/ai21/clients/studio/resources/studio_completion.py +++ b/ai21/clients/studio/resources/studio_completion.py @@ -29,12 +29,12 @@ def create( logit_bias: Dict[str, float] | NotGiven = NOT_GIVEN, **kwargs, ) -> CompletionsResponse: - url = f"{self._client.get_base_url()}/{model}" + path = f"/{model}" if custom_model: - url = f"{url}/{custom_model}" + path = f"{path}/{custom_model}" - url = f"{url}/{self._module_name}" + path = f"{path}/{self._module_name}" body = self._create_body( model=model, prompt=prompt, @@ -53,4 +53,4 @@ def create( logit_bias=logit_bias, **kwargs, ) - return self._post(url=url, body=body, response_cls=CompletionsResponse) + return self._post(path=path, body=body, response_cls=CompletionsResponse) diff --git a/ai21/clients/studio/resources/studio_custom_model.py b/ai21/clients/studio/resources/studio_custom_model.py index 61e351d1..41a45172 100644 --- a/ai21/clients/studio/resources/studio_custom_model.py +++ b/ai21/clients/studio/resources/studio_custom_model.py @@ -16,7 +16,6 @@ def create( num_epochs: Optional[int] = None, **kwargs, ) -> None: - url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body( dataset_id=dataset_id, model_name=model_name, @@ -25,12 +24,10 @@ def create( num_epochs=num_epochs, **kwargs, ) - self._post(url=url, body=body, response_cls=None) + self._post(path=f"/{self._module_name}", body=body, response_cls=None) def list(self) -> List[CustomBaseModelResponse]: - url = f"{self._client.get_base_url()}/{self._module_name}" - return self._get(url=url, response_cls=List[CustomBaseModelResponse]) + return self._get(path=f"/{self._module_name}", response_cls=List[CustomBaseModelResponse]) def get(self, resource_id: str) -> CustomBaseModelResponse: - url = f"{self._client.get_base_url()}/{self._module_name}/{resource_id}" - return self._get(url=url, response_cls=CustomBaseModelResponse) + return self._get(path=f"/{self._module_name}/{resource_id}", response_cls=CustomBaseModelResponse) diff --git a/ai21/clients/studio/resources/studio_dataset.py b/ai21/clients/studio/resources/studio_dataset.py index 10f26ddd..1fcb3f29 100644 --- a/ai21/clients/studio/resources/studio_dataset.py +++ b/ai21/clients/studio/resources/studio_dataset.py @@ -27,17 +27,13 @@ def create( **kwargs, ) return self._post( - url=self._base_url(), + path=f"/{self._module_name}", body=body, files=files, ) def list(self) -> List[DatasetResponse]: - return self._get(url=self._base_url(), response_cls=List[DatasetResponse]) + return self._get(path=f"/{self._module_name}", response_cls=List[DatasetResponse]) def get(self, dataset_pid: str) -> DatasetResponse: - url = f"{self._base_url()}/{dataset_pid}" - return self._get(url=url, response_cls=DatasetResponse) - - def _base_url(self) -> str: - return f"{self._client.get_base_url()}/{self._module_name}" + return self._get(path=f"/{self._module_name}/{dataset_pid}", response_cls=DatasetResponse) diff --git a/ai21/clients/studio/resources/studio_embed.py b/ai21/clients/studio/resources/studio_embed.py index e45b6269..054902be 100644 --- a/ai21/clients/studio/resources/studio_embed.py +++ b/ai21/clients/studio/resources/studio_embed.py @@ -7,7 +7,6 @@ class StudioEmbed(StudioResource, Embed): def create(self, texts: List[str], type: Optional[EmbedType] = None, **kwargs) -> EmbedResponse: - url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(texts=texts, type=type, **kwargs) - return self._post(url=url, body=body, response_cls=EmbedResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=EmbedResponse) diff --git a/ai21/clients/studio/resources/studio_gec.py b/ai21/clients/studio/resources/studio_gec.py index a8752c9c..106df366 100644 --- a/ai21/clients/studio/resources/studio_gec.py +++ b/ai21/clients/studio/resources/studio_gec.py @@ -6,6 +6,5 @@ class StudioGEC(StudioResource, GEC): def create(self, text: str, **kwargs) -> GECResponse: body = self._create_body(text=text, **kwargs) - url = f"{self._client.get_base_url()}/{self._module_name}" - return self._post(url=url, body=body, response_cls=GECResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=GECResponse) diff --git a/ai21/clients/studio/resources/studio_improvements.py b/ai21/clients/studio/resources/studio_improvements.py index 88ea996c..1191db86 100644 --- a/ai21/clients/studio/resources/studio_improvements.py +++ b/ai21/clients/studio/resources/studio_improvements.py @@ -11,7 +11,6 @@ def create(self, text: str, types: List[ImprovementType], **kwargs) -> Improveme if len(types) == 0: raise EmptyMandatoryListError("types") - url = f"{self._client.get_base_url()}/{self._module_name}" body = self._create_body(text=text, types=types, **kwargs) - return self._post(url=url, body=body, response_cls=ImprovementsResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=ImprovementsResponse) diff --git a/ai21/clients/studio/resources/studio_library.py b/ai21/clients/studio/resources/studio_library.py index e177df94..08acbb4e 100644 --- a/ai21/clients/studio/resources/studio_library.py +++ b/ai21/clients/studio/resources/studio_library.py @@ -31,18 +31,15 @@ def create( public_url: Optional[str] | NotGiven = NOT_GIVEN, **kwargs, ) -> str: - url = f"{self._client.get_base_url()}/{self._module_name}" files = {"file": open(file_path, "rb")} body = remove_not_given({"path": path, "labels": labels, "publicUrl": public_url, **kwargs}) - raw_response = self._post(url=url, files=files, body=body, response_cls=dict) + raw_response = self._post(path=f"/{self._module_name}", files=files, body=body, response_cls=dict) return raw_response["fileId"] def get(self, file_id: str) -> FileResponse: - url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}" - - return self._get(url=url, response_cls=FileResponse) + return self._get(path=f"/{self._module_name}/{file_id}", response_cls=FileResponse) def list( self, @@ -51,10 +48,9 @@ def list( limit: Optional[int] | NotGiven = NOT_GIVEN, **kwargs, ) -> List[FileResponse]: - url = f"{self._client.get_base_url()}/{self._module_name}" params = remove_not_given({"offset": offset, "limit": limit}) - return self._get(url=url, params=params, response_cls=List[FileResponse]) + return self._get(path=f"/{self._module_name}", params=params, response_cls=List[FileResponse]) def update( self, @@ -64,7 +60,6 @@ def update( labels: Optional[List[str]] | NotGiven = NOT_GIVEN, **kwargs, ) -> None: - url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}" body = remove_not_given( { "publicUrl": public_url, @@ -72,11 +67,10 @@ def update( **kwargs, } ) - self._put(url=url, body=body) + self._put(path=f"/{self._module_name}/{file_id}", body=body) def delete(self, file_id: str) -> None: - url = f"{self._client.get_base_url()}/{self._module_name}/{file_id}" - self._delete(url=url) + self._delete(path=f"/{self._module_name}/{file_id}") class LibrarySearch(StudioResource): @@ -91,7 +85,6 @@ def create( max_segments: Optional[int] | NotGiven = NOT_GIVEN, **kwargs, ) -> LibrarySearchResponse: - url = f"{self._client.get_base_url()}/{self._module_name}" body = remove_not_given( { "query": query, @@ -102,7 +95,7 @@ def create( } ) - return self._post(url=url, body=body, response_cls=LibrarySearchResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=LibrarySearchResponse) class LibraryAnswer(StudioResource): @@ -117,7 +110,6 @@ def create( labels: Optional[List[str]] | NotGiven = NOT_GIVEN, **kwargs, ) -> LibraryAnswerResponse: - url = f"{self._client.get_base_url()}/{self._module_name}" body = remove_not_given( { "question": question, @@ -128,4 +120,4 @@ def create( } ) - return self._post(url=url, body=body, response_cls=LibraryAnswerResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=LibraryAnswerResponse) diff --git a/ai21/clients/studio/resources/studio_paraphrase.py b/ai21/clients/studio/resources/studio_paraphrase.py index 25764e46..a4f829a8 100644 --- a/ai21/clients/studio/resources/studio_paraphrase.py +++ b/ai21/clients/studio/resources/studio_paraphrase.py @@ -22,6 +22,5 @@ def create( end_index=end_index, **kwargs, ) - url = f"{self._client.get_base_url()}/{self._module_name}" - return self._post(url=url, body=body, response_cls=ParaphraseResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=ParaphraseResponse) diff --git a/ai21/clients/studio/resources/studio_resource.py b/ai21/clients/studio/resources/studio_resource.py index 8f4be991..8abc4d5e 100644 --- a/ai21/clients/studio/resources/studio_resource.py +++ b/ai21/clients/studio/resources/studio_resource.py @@ -18,8 +18,9 @@ def __init__(self, client: AI21HTTPClient): def _post( self, - url: str, - body: Dict[str, Any], + path: str, + body: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, response_cls: Optional[ResponseT] = None, stream_cls: Optional[StreamT] = None, stream: bool = False, @@ -27,30 +28,31 @@ def _post( ) -> ResponseT | StreamT: response = self._client.execute_http_request( method="POST", - url=url, + path=path, stream=stream, - params=body or {}, + body=body or {}, + params=params or {}, files=files, ) return self._cast_response(stream=stream, response=response, response_cls=response_cls, stream_cls=stream_cls) def _get( - self, url: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None + self, path: str, response_cls: Optional[ResponseT] = None, params: Optional[Dict[str, Any]] = None ) -> ResponseT | StreamT: - response = self._client.execute_http_request(method="GET", url=url, params=params or {}) + response = self._client.execute_http_request(method="GET", path=path, params=params or {}) return self._cast_response(response=response, response_cls=response_cls) def _put( - self, url: str, response_cls: Optional[ResponseT] = None, body: Dict[str, Any] = None + self, path: str, response_cls: Optional[ResponseT] = None, body: Dict[str, Any] = None ) -> ResponseT | StreamT: - response = self._client.execute_http_request(method="PUT", url=url, params=body or {}) + response = self._client.execute_http_request(method="PUT", path=path, body=body or {}) return self._cast_response(response=response, response_cls=response_cls) - def _delete(self, url: str, response_cls: Optional[ResponseT] = None) -> ResponseT | StreamT: + def _delete(self, path: str, response_cls: Optional[ResponseT] = None) -> ResponseT | StreamT: response = self._client.execute_http_request( method="DELETE", - url=url, + path=path, ) return self._cast_response(response=response, response_cls=response_cls) diff --git a/ai21/clients/studio/resources/studio_segmentation.py b/ai21/clients/studio/resources/studio_segmentation.py index a2aee960..94041a74 100644 --- a/ai21/clients/studio/resources/studio_segmentation.py +++ b/ai21/clients/studio/resources/studio_segmentation.py @@ -6,6 +6,5 @@ class StudioSegmentation(StudioResource, Segmentation): def create(self, source: str, source_type: DocumentType, **kwargs) -> SegmentationResponse: body = self._create_body(source=source, source_type=source_type.value, **kwargs) - url = f"{self._client.get_base_url()}/{self._module_name}" - return self._post(url=url, body=body, response_cls=SegmentationResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=SegmentationResponse) diff --git a/ai21/clients/studio/resources/studio_summarize.py b/ai21/clients/studio/resources/studio_summarize.py index 6ba4f9fe..aeefc80f 100644 --- a/ai21/clients/studio/resources/studio_summarize.py +++ b/ai21/clients/studio/resources/studio_summarize.py @@ -24,6 +24,5 @@ def create( summary_method=summary_method, **kwargs, ) - url = f"{self._client.get_base_url()}/{self._module_name}" - return self._post(url=url, body=body, response_cls=SummarizeResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeResponse) diff --git a/ai21/clients/studio/resources/studio_summarize_by_segment.py b/ai21/clients/studio/resources/studio_summarize_by_segment.py index abb1705e..93ae09f6 100644 --- a/ai21/clients/studio/resources/studio_summarize_by_segment.py +++ b/ai21/clients/studio/resources/studio_summarize_by_segment.py @@ -15,6 +15,5 @@ def create( focus=focus, **kwargs, ) - url = f"{self._client.get_base_url()}/{self._module_name}" - return self._post(url=url, body=body, response_cls=SummarizeBySegmentResponse) + return self._post(path=f"/{self._module_name}", body=body, response_cls=SummarizeBySegmentResponse) diff --git a/ai21/http_client.py b/ai21/http_client.py index 834869fa..1a26248f 100644 --- a/ai21/http_client.py +++ b/ai21/http_client.py @@ -76,11 +76,12 @@ def execute_http_request( method: str, url: str, params: Optional[Dict] = None, + body: Optional[Dict] = None, stream: bool = False, files: Optional[Dict[str, BinaryIO]] = None, ) -> httpx.Response: try: - response = self._request(files=files, method=method, params=params, url=url, stream=stream) + response = self._request(files=files, method=method, params=params, url=url, stream=stream, body=body) except RetryError as retry_error: last_attempt = retry_error.last_attempt @@ -107,6 +108,7 @@ def _request( files: Optional[Dict[str, BinaryIO]], method: str, params: Optional[Dict], + body: Optional[Dict], url: str, stream: bool, ) -> httpx.Response: @@ -130,14 +132,15 @@ def _request( headers.pop( "Content-Type" ) # multipart/form-data 'Content-Type' is being added when passing rb files and payload - data = params + data = body else: - data = json.dumps(params).encode() if params else None + data = json.dumps(body).encode() if body else None request = self._client.build_request( method=method, url=url, headers=headers, + params=params, data=data, timeout=timeout, files=files, diff --git a/ai21/services/sagemaker.py b/ai21/services/sagemaker.py index 3132cd1a..3b72b294 100644 --- a/ai21/services/sagemaker.py +++ b/ai21/services/sagemaker.py @@ -23,8 +23,8 @@ def get_model_package_arn(cls, model_name: str, region: str, version: str = LATE response = client.execute_http_request( method="POST", - url=f"{client.get_base_url()}/{_GET_ARN_ENDPOINT}", - params={ + path=f"/{_GET_ARN_ENDPOINT}", + body={ "modelName": model_name, "region": region, "version": version, @@ -46,8 +46,8 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: response = client.execute_http_request( method="POST", - url=f"{client.get_base_url()}/{_LIST_VERSIONS_ENDPOINT}", - params={ + path=f"/{_LIST_VERSIONS_ENDPOINT}", + body={ "modelName": model_name, "region": region, }, @@ -59,7 +59,7 @@ def list_model_package_versions(cls, model_name: str, region: str) -> List[str]: def _create_ai21_http_client(cls) -> AI21HTTPClient: return AI21HTTPClient( api_key=AI21EnvConfig.api_key, - api_host=AI21EnvConfig.api_host, + base_url=f"{AI21EnvConfig.api_host}/studio/v1", requires_api_key=False, api_version=AI21EnvConfig.api_version, timeout_sec=AI21EnvConfig.timeout_sec, From c5bf9e1827ae1324348a50ed5298ef9be1e651ea Mon Sep 17 00:00:00 2001 From: Josephasafg Date: Thu, 6 Jun 2024 13:09:25 +0300 Subject: [PATCH 2/2] fix: tests --- ai21/ai21_http_client.py | 4 +- .../studio/resources/test_studio_resources.py | 12 +++--- tests/unittests/test_ai21_http_client.py | 40 ++++++++----------- 3 files changed, 25 insertions(+), 31 deletions(-) diff --git a/ai21/ai21_http_client.py b/ai21/ai21_http_client.py index 87e786d2..804d9947 100644 --- a/ai21/ai21_http_client.py +++ b/ai21/ai21_http_client.py @@ -85,8 +85,8 @@ def execute_http_request( return self._http_client.execute_http_request( method=method, url=f"{self._base_url}{path}", - params=params, + params=params or {}, files=files, stream=stream, - body=body, + body=body or {}, ) diff --git a/tests/unittests/clients/studio/resources/test_studio_resources.py b/tests/unittests/clients/studio/resources/test_studio_resources.py index f5ad3196..e3a32859 100644 --- a/tests/unittests/clients/studio/resources/test_studio_resources.py +++ b/tests/unittests/clients/studio/resources/test_studio_resources.py @@ -77,7 +77,6 @@ def test__create__should_return_response( mock_ai21_studio_client: AI21HTTPClient, ): mock_ai21_studio_client.execute_http_request.return_value = expected_httpx_response - mock_ai21_studio_client.get_base_url.return_value = _BASE_URL resource = studio_resource(mock_ai21_studio_client) @@ -88,8 +87,9 @@ def test__create__should_return_response( assert actual_response == expected_response mock_ai21_studio_client.execute_http_request.assert_called_with( method="POST", - url=f"{_BASE_URL}/{url_suffix}", - params=expected_body, + path=f"/{url_suffix}", + body=expected_body, + params={}, stream=False, files=None, ) @@ -103,7 +103,6 @@ def test__create__when_pass_kwargs__should_pass_to_request( mock_successful_httpx_response.json.return_value = expected_answer.to_dict() mock_ai21_studio_client.execute_http_request.return_value = mock_successful_httpx_response - mock_ai21_studio_client.get_base_url.return_value = _BASE_URL studio_answer = StudioAnswer(mock_ai21_studio_client) studio_answer.create( @@ -114,12 +113,13 @@ def test__create__when_pass_kwargs__should_pass_to_request( mock_ai21_studio_client.execute_http_request.assert_called_with( method="POST", - url=_BASE_URL + "/answer", - params={ + path="/answer", + body={ "context": _DUMMY_CONTEXT, "question": _DUMMY_QUESTION, "some_dummy_kwargs": "some_dummy_value", }, + params={}, stream=False, files=None, ) diff --git a/tests/unittests/test_ai21_http_client.py b/tests/unittests/test_ai21_http_client.py index 3ae8f6c6..de7a6c19 100644 --- a/tests/unittests/test_ai21_http_client.py +++ b/tests/unittests/test_ai21_http_client.py @@ -64,20 +64,6 @@ def test__build_headers__when_pass_headers__should_append(): assert client._http_client._headers["Authorization"] == f"Bearer {_DUMMY_API_KEY}" -@pytest.mark.parametrize( - ids=[ - "when_api_host_is_set__should_return_set_value", - ], - argnames=["api_host", "expected_api_host"], - argvalues=[ - ("http://test_host", "http://test_host/studio/v1"), - ], -) -def test__get_base_url(api_host: Optional[str], expected_api_host: str): - client = AI21HTTPClient(api_key=_DUMMY_API_KEY, api_host=api_host, api_version="v1") - assert client.get_base_url() == expected_api_host - - @pytest.mark.parametrize( ids=[ "when_making_request__should_send_appropriate_parameters", @@ -85,12 +71,13 @@ def test__get_base_url(api_host: Optional[str], expected_api_host: str): ], argnames=["params", "headers"], argvalues=[ - ({"method": "GET", "url": "test_url", "params": {"foo": "bar"}}, _EXPECTED_GET_HEADERS), + ({"method": "GET", "path": "/test_url", "params": {"foo": "bar"}, "body": {}}, _EXPECTED_GET_HEADERS), ( { "method": "POST", - "url": "test_url", - "params": {"foo": "bar"}, + "path": "/test_url", + "body": {"foo": "bar"}, + "params": {}, "stream": False, "files": {"file": "test_file"}, }, @@ -110,7 +97,7 @@ def test__execute_http_request__( mock_httpx_client.send.return_value = MockResponse(response_json, 200) http_client = HttpClient(client=mock_httpx_client) - client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, base_url=dummy_api_host, api_version="v1") response = client.execute_http_request(**params) assert response.json() == response_json @@ -121,12 +108,19 @@ def test__execute_http_request__( timeout=300, headers=headers, files=params["files"], - data=params["params"], - url=params["url"], + data=params["body"], + params=params["params"], + url=f"{dummy_api_host}{params['path']}", method=params["method"], ) else: - mock_httpx_client.build_request.assert_called_once_with(timeout=300, headers=headers, **params) + mock_httpx_client.build_request.assert_called_once_with( + timeout=300, + headers=headers, + url=f"{dummy_api_host}{params['path']}", + params=params["params"], + method=params["method"], + ) mock_httpx_client.send.assert_called_once_with(request=mock_response, stream=False) @@ -137,9 +131,9 @@ def test__execute_http_request__when_files_with_put_method__should_raise_value_e ): response_json = {"test_key": "test_value"} http_client = HttpClient(client=mock_httpx_client) - client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, api_host=dummy_api_host, api_version="v1") + client = AI21HTTPClient(http_client=http_client, api_key=_DUMMY_API_KEY, base_url=dummy_api_host, api_version="v1") mock_httpx_client.request.return_value = MockResponse(response_json, 200) with pytest.raises(ValueError): - params = {"method": "PUT", "url": "test_url", "params": {"foo": "bar"}, "files": {"file": "test_file"}} + params = {"method": "PUT", "path": "test_url", "body": {"foo": "bar"}, "files": {"file": "test_file"}} client.execute_http_request(**params)