diff --git a/.stats.yml b/.stats.yml index 48f88c4..de0574c 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1 +1 @@ -configured_endpoints: 43 +configured_endpoints: 44 diff --git a/api.md b/api.md index 30db665..6becc1e 100644 --- a/api.md +++ b/api.md @@ -174,6 +174,7 @@ from dataherald.types import NlGenerationListResponse Methods: +- client.nl_generations.create(\*\*params) -> NlGenerationResponse - client.nl_generations.retrieve(id) -> NlGenerationResponse - client.nl_generations.list(\*\*params) -> NlGenerationListResponse diff --git a/src/dataherald/resources/nl_generations.py b/src/dataherald/resources/nl_generations.py index be38bee..d1d3b4c 100644 --- a/src/dataherald/resources/nl_generations.py +++ b/src/dataherald/resources/nl_generations.py @@ -4,7 +4,7 @@ import httpx -from ..types import NlGenerationListResponse, nl_generation_list_params +from ..types import NlGenerationListResponse, nl_generation_list_params, nl_generation_create_params from .._types import ( NOT_GIVEN, Body, @@ -29,6 +29,47 @@ class NlGenerations(SyncAPIResource): def with_raw_response(self) -> NlGenerationsWithRawResponse: return NlGenerationsWithRawResponse(self) + def create( + self, + *, + sql_generation: nl_generation_create_params.SqlGeneration, + max_rows: int | NotGiven = NOT_GIVEN, + metadata: object | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> NlGenerationResponse: + """ + Create Prompt Sql Nl Generation + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return self._post( + "/api/prompts/sql-generations/nl-generations", + body=maybe_transform( + { + "sql_generation": sql_generation, + "max_rows": max_rows, + "metadata": metadata, + }, + nl_generation_create_params.NlGenerationCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NlGenerationResponse, + ) + def retrieve( self, id: str, @@ -112,6 +153,47 @@ class AsyncNlGenerations(AsyncAPIResource): def with_raw_response(self) -> AsyncNlGenerationsWithRawResponse: return AsyncNlGenerationsWithRawResponse(self) + async def create( + self, + *, + sql_generation: nl_generation_create_params.SqlGeneration, + max_rows: int | NotGiven = NOT_GIVEN, + metadata: object | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> NlGenerationResponse: + """ + Create Prompt Sql Nl Generation + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + return await self._post( + "/api/prompts/sql-generations/nl-generations", + body=maybe_transform( + { + "sql_generation": sql_generation, + "max_rows": max_rows, + "metadata": metadata, + }, + nl_generation_create_params.NlGenerationCreateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NlGenerationResponse, + ) + async def retrieve( self, id: str, @@ -192,6 +274,9 @@ async def list( class NlGenerationsWithRawResponse: def __init__(self, nl_generations: NlGenerations) -> None: + self.create = to_raw_response_wrapper( + nl_generations.create, + ) self.retrieve = to_raw_response_wrapper( nl_generations.retrieve, ) @@ -202,6 +287,9 @@ def __init__(self, nl_generations: NlGenerations) -> None: class AsyncNlGenerationsWithRawResponse: def __init__(self, nl_generations: AsyncNlGenerations) -> None: + self.create = async_to_raw_response_wrapper( + nl_generations.create, + ) self.retrieve = async_to_raw_response_wrapper( nl_generations.retrieve, ) diff --git a/src/dataherald/types/__init__.py b/src/dataherald/types/__init__.py index 69e70b7..9b59f05 100644 --- a/src/dataherald/types/__init__.py +++ b/src/dataherald/types/__init__.py @@ -33,6 +33,7 @@ from .golden_sql_upload_response import GoldenSqlUploadResponse as GoldenSqlUploadResponse from .sql_generation_list_params import SqlGenerationListParams as SqlGenerationListParams from .table_description_response import TableDescriptionResponse as TableDescriptionResponse +from .nl_generation_create_params import NlGenerationCreateParams as NlGenerationCreateParams from .nl_generation_list_response import NlGenerationListResponse as NlGenerationListResponse from .sql_generation_create_params import SqlGenerationCreateParams as SqlGenerationCreateParams from .sql_generation_list_response import SqlGenerationListResponse as SqlGenerationListResponse diff --git a/src/dataherald/types/nl_generation_create_params.py b/src/dataherald/types/nl_generation_create_params.py new file mode 100644 index 0000000..c8c997b --- /dev/null +++ b/src/dataherald/types/nl_generation_create_params.py @@ -0,0 +1,35 @@ +# File generated from our OpenAPI spec by Stainless. + +from __future__ import annotations + +from typing_extensions import Required, TypedDict + +__all__ = ["NlGenerationCreateParams", "SqlGeneration", "SqlGenerationPrompt"] + + +class NlGenerationCreateParams(TypedDict, total=False): + sql_generation: Required[SqlGeneration] + + max_rows: int + + metadata: object + + +class SqlGenerationPrompt(TypedDict, total=False): + db_connection_id: Required[str] + + text: Required[str] + + metadata: object + + +class SqlGeneration(TypedDict, total=False): + prompt: Required[SqlGenerationPrompt] + + evaluate: bool + + finetuning_id: str + + metadata: object + + sql: str diff --git a/tests/api_resources/test_nl_generations.py b/tests/api_resources/test_nl_generations.py index 8ec86de..90b17f8 100644 --- a/tests/api_resources/test_nl_generations.py +++ b/tests/api_resources/test_nl_generations.py @@ -21,6 +21,51 @@ class TestNlGenerations: loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + @parametrize + def test_method_create(self, client: Dataherald) -> None: + nl_generation = client.nl_generations.create( + sql_generation={ + "prompt": { + "text": "string", + "db_connection_id": "string", + } + }, + ) + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + @parametrize + def test_method_create_with_all_params(self, client: Dataherald) -> None: + nl_generation = client.nl_generations.create( + sql_generation={ + "finetuning_id": "string", + "evaluate": True, + "sql": "string", + "metadata": {}, + "prompt": { + "text": "string", + "db_connection_id": "string", + "metadata": {}, + }, + }, + max_rows=0, + metadata={}, + ) + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + @parametrize + def test_raw_response_create(self, client: Dataherald) -> None: + response = client.nl_generations.with_raw_response.create( + sql_generation={ + "prompt": { + "text": "string", + "db_connection_id": "string", + } + }, + ) + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + nl_generation = response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: nl_generation = client.nl_generations.retrieve( @@ -65,6 +110,51 @@ class TestAsyncNlGenerations: loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + @parametrize + async def test_method_create(self, client: AsyncDataherald) -> None: + nl_generation = await client.nl_generations.create( + sql_generation={ + "prompt": { + "text": "string", + "db_connection_id": "string", + } + }, + ) + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + @parametrize + async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: + nl_generation = await client.nl_generations.create( + sql_generation={ + "finetuning_id": "string", + "evaluate": True, + "sql": "string", + "metadata": {}, + "prompt": { + "text": "string", + "db_connection_id": "string", + "metadata": {}, + }, + }, + max_rows=0, + metadata={}, + ) + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + @parametrize + async def test_raw_response_create(self, client: AsyncDataherald) -> None: + response = await client.nl_generations.with_raw_response.create( + sql_generation={ + "prompt": { + "text": "string", + "db_connection_id": "string", + } + }, + ) + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + nl_generation = response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: nl_generation = await client.nl_generations.retrieve(