Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .stats.yml
Original file line number Diff line number Diff line change
@@ -1 +1 @@
configured_endpoints: 43
configured_endpoints: 44
1 change: 1 addition & 0 deletions api.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ from dataherald.types import NlGenerationListResponse

Methods:

- <code title="post /api/prompts/sql-generations/nl-generations">client.nl_generations.<a href="./src/dataherald/resources/nl_generations.py">create</a>(\*\*<a href="src/dataherald/types/nl_generation_create_params.py">params</a>) -> <a href="./src/dataherald/types/shared/nl_generation_response.py">NlGenerationResponse</a></code>
- <code title="get /api/nl-generations/{id}">client.nl_generations.<a href="./src/dataherald/resources/nl_generations.py">retrieve</a>(id) -> <a href="./src/dataherald/types/shared/nl_generation_response.py">NlGenerationResponse</a></code>
- <code title="get /api/nl-generations">client.nl_generations.<a href="./src/dataherald/resources/nl_generations.py">list</a>(\*\*<a href="src/dataherald/types/nl_generation_list_params.py">params</a>) -> <a href="./src/dataherald/types/nl_generation_list_response.py">NlGenerationListResponse</a></code>

Expand Down
90 changes: 89 additions & 1 deletion src/dataherald/resources/nl_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import httpx

from ..types import NlGenerationListResponse, nl_generation_list_params
from ..types import NlGenerationListResponse, nl_generation_list_params, nl_generation_create_params
from .._types import (
NOT_GIVEN,
Body,
Expand All @@ -29,6 +29,47 @@ class NlGenerations(SyncAPIResource):
def with_raw_response(self) -> NlGenerationsWithRawResponse:
return NlGenerationsWithRawResponse(self)

def create(
self,
*,
sql_generation: nl_generation_create_params.SqlGeneration,
max_rows: int | NotGiven = NOT_GIVEN,
metadata: object | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> NlGenerationResponse:
"""
Create Prompt Sql Nl Generation

Args:
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
return self._post(
"/api/prompts/sql-generations/nl-generations",
body=maybe_transform(
{
"sql_generation": sql_generation,
"max_rows": max_rows,
"metadata": metadata,
},
nl_generation_create_params.NlGenerationCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=NlGenerationResponse,
)

def retrieve(
self,
id: str,
Expand Down Expand Up @@ -112,6 +153,47 @@ class AsyncNlGenerations(AsyncAPIResource):
def with_raw_response(self) -> AsyncNlGenerationsWithRawResponse:
return AsyncNlGenerationsWithRawResponse(self)

async def create(
self,
*,
sql_generation: nl_generation_create_params.SqlGeneration,
max_rows: int | NotGiven = NOT_GIVEN,
metadata: object | NotGiven = NOT_GIVEN,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
extra_query: Query | None = None,
extra_body: Body | None = None,
timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
) -> NlGenerationResponse:
"""
Create Prompt Sql Nl Generation

Args:
extra_headers: Send extra headers

extra_query: Add additional query parameters to the request

extra_body: Add additional JSON properties to the request

timeout: Override the client-level default timeout for this request, in seconds
"""
return await self._post(
"/api/prompts/sql-generations/nl-generations",
body=maybe_transform(
{
"sql_generation": sql_generation,
"max_rows": max_rows,
"metadata": metadata,
},
nl_generation_create_params.NlGenerationCreateParams,
),
options=make_request_options(
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
),
cast_to=NlGenerationResponse,
)

async def retrieve(
self,
id: str,
Expand Down Expand Up @@ -192,6 +274,9 @@ async def list(

class NlGenerationsWithRawResponse:
def __init__(self, nl_generations: NlGenerations) -> None:
self.create = to_raw_response_wrapper(
nl_generations.create,
)
self.retrieve = to_raw_response_wrapper(
nl_generations.retrieve,
)
Expand All @@ -202,6 +287,9 @@ def __init__(self, nl_generations: NlGenerations) -> None:

class AsyncNlGenerationsWithRawResponse:
def __init__(self, nl_generations: AsyncNlGenerations) -> None:
self.create = async_to_raw_response_wrapper(
nl_generations.create,
)
self.retrieve = async_to_raw_response_wrapper(
nl_generations.retrieve,
)
Expand Down
1 change: 1 addition & 0 deletions src/dataherald/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from .golden_sql_upload_response import GoldenSqlUploadResponse as GoldenSqlUploadResponse
from .sql_generation_list_params import SqlGenerationListParams as SqlGenerationListParams
from .table_description_response import TableDescriptionResponse as TableDescriptionResponse
from .nl_generation_create_params import NlGenerationCreateParams as NlGenerationCreateParams
from .nl_generation_list_response import NlGenerationListResponse as NlGenerationListResponse
from .sql_generation_create_params import SqlGenerationCreateParams as SqlGenerationCreateParams
from .sql_generation_list_response import SqlGenerationListResponse as SqlGenerationListResponse
Expand Down
35 changes: 35 additions & 0 deletions src/dataherald/types/nl_generation_create_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# File generated from our OpenAPI spec by Stainless.

from __future__ import annotations

from typing_extensions import Required, TypedDict

__all__ = ["NlGenerationCreateParams", "SqlGeneration", "SqlGenerationPrompt"]


class NlGenerationCreateParams(TypedDict, total=False):
sql_generation: Required[SqlGeneration]

max_rows: int

metadata: object


class SqlGenerationPrompt(TypedDict, total=False):
db_connection_id: Required[str]

text: Required[str]

metadata: object


class SqlGeneration(TypedDict, total=False):
prompt: Required[SqlGenerationPrompt]

evaluate: bool

finetuning_id: str

metadata: object

sql: str
90 changes: 90 additions & 0 deletions tests/api_resources/test_nl_generations.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,51 @@ class TestNlGenerations:
loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False)
parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])

@parametrize
def test_method_create(self, client: Dataherald) -> None:
nl_generation = client.nl_generations.create(
sql_generation={
"prompt": {
"text": "string",
"db_connection_id": "string",
}
},
)
assert_matches_type(NlGenerationResponse, nl_generation, path=["response"])

@parametrize
def test_method_create_with_all_params(self, client: Dataherald) -> None:
nl_generation = client.nl_generations.create(
sql_generation={
"finetuning_id": "string",
"evaluate": True,
"sql": "string",
"metadata": {},
"prompt": {
"text": "string",
"db_connection_id": "string",
"metadata": {},
},
},
max_rows=0,
metadata={},
)
assert_matches_type(NlGenerationResponse, nl_generation, path=["response"])

@parametrize
def test_raw_response_create(self, client: Dataherald) -> None:
response = client.nl_generations.with_raw_response.create(
sql_generation={
"prompt": {
"text": "string",
"db_connection_id": "string",
}
},
)
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
nl_generation = response.parse()
assert_matches_type(NlGenerationResponse, nl_generation, path=["response"])

@parametrize
def test_method_retrieve(self, client: Dataherald) -> None:
nl_generation = client.nl_generations.retrieve(
Expand Down Expand Up @@ -65,6 +110,51 @@ class TestAsyncNlGenerations:
loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False)
parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"])

@parametrize
async def test_method_create(self, client: AsyncDataherald) -> None:
nl_generation = await client.nl_generations.create(
sql_generation={
"prompt": {
"text": "string",
"db_connection_id": "string",
}
},
)
assert_matches_type(NlGenerationResponse, nl_generation, path=["response"])

@parametrize
async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None:
nl_generation = await client.nl_generations.create(
sql_generation={
"finetuning_id": "string",
"evaluate": True,
"sql": "string",
"metadata": {},
"prompt": {
"text": "string",
"db_connection_id": "string",
"metadata": {},
},
},
max_rows=0,
metadata={},
)
assert_matches_type(NlGenerationResponse, nl_generation, path=["response"])

@parametrize
async def test_raw_response_create(self, client: AsyncDataherald) -> None:
response = await client.nl_generations.with_raw_response.create(
sql_generation={
"prompt": {
"text": "string",
"db_connection_id": "string",
}
},
)
assert response.http_request.headers.get("X-Stainless-Lang") == "python"
nl_generation = response.parse()
assert_matches_type(NlGenerationResponse, nl_generation, path=["response"])

@parametrize
async def test_method_retrieve(self, client: AsyncDataherald) -> None:
nl_generation = await client.nl_generations.retrieve(
Expand Down