Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .chronus/changes/add-test-2024-7-8-18-12-59.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
changeKind: fix
packages:
- "@azure-tools/typespec-python"
---

Fix to get right response and exception
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .base import BaseType
from .enum_type import EnumType
from .model_type import ModelType
from .model_type import ModelType, UsageFlags
from .combined_type import CombinedType
from .client import Client
from .request_builder import RequestBuilder, OverloadedRequestBuilder
Expand Down Expand Up @@ -162,9 +162,7 @@ def model_types(self) -> List[ModelType]:
"""All of the model types in this class"""
if not self._model_types:
self._model_types = [
t
for t in self.types_map.values()
if isinstance(t, ModelType) and not (self.options["models_mode"] == "dpg" and t.page_result_model)
t for t in self.types_map.values() if isinstance(t, ModelType) and t.usage != UsageFlags.Default.value
]
return self._model_types

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ def __init__(
self._got_polymorphic_subtypes = False
self.internal: bool = self.yaml_data.get("internal", False)
self.snake_case_name: str = self.yaml_data["snakeCaseName"]
self.page_result_model: bool = self.yaml_data.get("pageResultModel", False)
self.cross_language_definition_id: Optional[str] = self.yaml_data.get("crossLanguageDefinitionId")
self.usage: int = self.yaml_data.get("usage", 0)
self.usage: int = self.yaml_data.get("usage", UsageFlags.Input.value | UsageFlags.Output.value)

@property
def is_usage_output(self) -> bool:
Expand Down
23 changes: 15 additions & 8 deletions packages/typespec-python/src/http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import {
getDescriptionAndSummary,
getImplementation,
isAbstract,
isAzureCoreModel,
isAzureCoreErrorResponse,
} from "./utils.js";
import { KnownTypes, getType } from "./types.js";
import { PythonSdkContext } from "./lib.js";
Expand Down Expand Up @@ -68,7 +68,7 @@ function emitInitialLroHttpMethod(
operationGroupName: string,
): Record<string, any> {
return {
...emitHttpOperation(context, rootClient, operationGroupName, method.operation),
...emitHttpOperation(context, rootClient, operationGroupName, method.operation, method),
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python SDK always set response type of initial operation as Binary type so just keep it same as begin_xxx API.

name: `_${camelToSnakeCase(method.name)}_initial`,
isLroInitialOperation: true,
wantTracing: false,
Expand Down Expand Up @@ -102,8 +102,8 @@ function addPagingInformation(
operationGroupName: string,
) {
for (const response of method.operation.responses.values()) {
if (response.type && !isAzureCoreModel(response.type)) {
getType(context, response.type)["pageResultModel"] = true;
if (response.type) {
getType(context, response.type)["usage"] = UsageFlags.None;
}
}
const itemType = getType(context, method.response.type!);
Expand Down Expand Up @@ -168,7 +168,7 @@ function emitHttpOperation(
responses.push(emitHttpResponse(context, statusCodes, response, method)!);
}
for (const [statusCodes, exception] of operation.exceptions) {
exceptions.push(emitHttpResponse(context, statusCodes, exception)!);
exceptions.push(emitHttpResponse(context, statusCodes, exception, undefined, true)!);
}
const result = {
url: operation.path,
Expand Down Expand Up @@ -326,13 +326,20 @@ function emitHttpResponse(
statusCodes: HttpStatusCodeRange | number | "*",
response: SdkHttpResponse,
method?: SdkServiceMethod<SdkHttpOperation>,
isException = false,
): Record<string, any> | undefined {
if (!response) return undefined;
let type = undefined;
if (response.type && !isAzureCoreModel(response.type)) {
if (isException) {
if (response.type && !isAzureCoreErrorResponse(response.type)) {
type = getType(context, response.type);
}
} else if (method && !method.kind.includes("basic")) {
if (method.response.type) {
type = getType(context, method.response.type);
}
} else if (response.type) {
Comment on lines +333 to +341
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • For exception response, need to filter Azure.Core.ErrorResponse
  • For normal response, there are 2 kinds:
  1. lro/lropaging/paging API: there shall always be one kind of response type so use method.response.type directly
  2. basic API: there may be multi response types, so honor http response type.

type = getType(context, response.type);
} else if (method && method.response.type && !isAzureCoreModel(method.response.type)) {
type = getType(context, method.response.type);
}
return {
headers: response.headers.map((x) => emitHttpResponseHeader(context, x)),
Expand Down
5 changes: 3 additions & 2 deletions packages/typespec-python/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,15 @@ export function emitParamBase<TServiceOperation extends SdkServiceOperation>(
};
}

export function isAzureCoreModel(t: SdkType | undefined): boolean {
export function isAzureCoreErrorResponse(t: SdkType | undefined): boolean {
if (!t) return false;
const tspType = t.__raw;
if (!tspType) return false;
return (
tspType.kind === "Model" &&
tspType.namespace !== undefined &&
["Azure.Core", "Azure.Core.Foundations"].includes(getNamespaceFullName(tspType.namespace))
["Azure.Core", "Azure.Core.Foundations"].includes(getNamespaceFullName(tspType.namespace)) &&
tspType.name === "ErrorResponse"
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,20 @@ def __init__(self, *args, **kwargs) -> None:
self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")

@distributed_trace_async
async def get(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-statements
async def get(self, **kwargs: Any) -> List[int]:
"""get an embedding vector.

:return: None
:rtype: None
:return: list of int
:rtype: list[int]
:raises ~azure.core.exceptions.HttpResponseError:

Example:
.. code-block:: python

# response body for status code(s): 200
response == [
0
]
"""
error_map: MutableMapping[int, Type[HttpResponseError]] = {
401: ClientAuthenticationError,
Expand All @@ -79,7 +87,7 @@ async def get(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-retu
_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[None] = kwargs.pop("cls", None)
cls: ClsType[List[int]] = kwargs.pop("cls", None)

_request = build_azure_core_embedding_vector_get_request(
headers=_headers,
Expand All @@ -90,19 +98,31 @@ async def get(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-retu
}
_request.url = self._client.format_url(_request.url, **path_format_arguments)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
try:
await response.read() # Load the body in memory and close the socket
except (StreamConsumedError, StreamClosedError):
pass
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

if _stream:
deserialized = response.iter_bytes()
else:
deserialized = _deserialize(List[int], response.json())

if cls:
return cls(pipeline_response, None, {}) # type: ignore
return cls(pipeline_response, deserialized, {}) # type: ignore

return deserialized # type: ignore

@overload
async def put( # pylint: disable=inconsistent-return-statements
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,20 @@ def __init__(self, *args, **kwargs):
self._deserialize = input_args.pop(0) if input_args else kwargs.pop("deserializer")

@distributed_trace
def get(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-statements
def get(self, **kwargs: Any) -> List[int]:
"""get an embedding vector.

:return: None
:rtype: None
:return: list of int
:rtype: list[int]
:raises ~azure.core.exceptions.HttpResponseError:

Example:
.. code-block:: python

# response body for status code(s): 200
response == [
0
]
"""
error_map: MutableMapping[int, Type[HttpResponseError]] = {
401: ClientAuthenticationError,
Expand All @@ -123,7 +131,7 @@ def get(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-sta
_headers = kwargs.pop("headers", {}) or {}
_params = kwargs.pop("params", {}) or {}

cls: ClsType[None] = kwargs.pop("cls", None)
cls: ClsType[List[int]] = kwargs.pop("cls", None)

_request = build_azure_core_embedding_vector_get_request(
headers=_headers,
Expand All @@ -134,19 +142,31 @@ def get(self, **kwargs: Any) -> None: # pylint: disable=inconsistent-return-sta
}
_request.url = self._client.format_url(_request.url, **path_format_arguments)

_stream = False
_stream = kwargs.pop("stream", False)
pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access
_request, stream=_stream, **kwargs
)

response = pipeline_response.http_response

if response.status_code not in [200]:
if _stream:
try:
response.read() # Load the body in memory and close the socket
except (StreamConsumedError, StreamClosedError):
pass
map_error(status_code=response.status_code, response=response, error_map=error_map)
raise HttpResponseError(response=response)

if _stream:
deserialized = response.iter_bytes()
else:
deserialized = _deserialize(List[int], response.json())

if cls:
return cls(pipeline_response, None, {}) # type: ignore
return cls(pipeline_response, deserialized, {}) # type: ignore

return deserialized # type: ignore

@overload
def put( # pylint: disable=inconsistent-return-statements
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import pytest
from specs.azure.core.model.aio import ModelClient
from specs.azure.core.model.models import AzureEmbeddingModel


@pytest.fixture
async def client():
async with ModelClient() as client:
yield client


@pytest.mark.asyncio
async def test_azure_core_embedding_vector_post(client: ModelClient):
embedding_model = AzureEmbeddingModel(embedding=[0, 1, 2, 3, 4])
result = await client.azure_core_embedding_vector.post(
body=embedding_model,
)
assert result == AzureEmbeddingModel(embedding=[5, 6, 7, 8, 9])


@pytest.mark.asyncio
async def test_azure_core_embedding_vector_put(client: ModelClient):
await client.azure_core_embedding_vector.put(body=[0, 1, 2, 3, 4])


@pytest.mark.asyncio
async def test_azure_core_embedding_vector_get(client: ModelClient):
assert [0, 1, 2, 3, 4] == (await client.azure_core_embedding_vector.get())
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import pytest
from specs.azure.core.model import ModelClient
from specs.azure.core.model.models import AzureEmbeddingModel


@pytest.fixture
def client():
with ModelClient() as client:
yield client


def test_azure_core_embedding_vector_post(client: ModelClient):
embedding_model = AzureEmbeddingModel(embedding=[0, 1, 2, 3, 4])
result = client.azure_core_embedding_vector.post(
body=embedding_model,
)
assert result == AzureEmbeddingModel(embedding=[5, 6, 7, 8, 9])


def test_azure_core_embedding_vector_put(client: ModelClient):
client.azure_core_embedding_vector.put(body=[0, 1, 2, 3, 4])


def test_azure_core_embedding_vector_get(client: ModelClient):
assert [0, 1, 2, 3, 4] == client.azure_core_embedding_vector.get()
1 change: 1 addition & 0 deletions packages/typespec-python/test/azure/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ azure-mgmt-core==1.3.2
-e ./generated/azure-core-scalar
-e ./generated/azurecore-lro-rpc
-e ./generated/azure-core-lro-standard
-e ./generated/azure-core-model
-e ./generated/azure-core-traits
-e ./generated/azure-core-page
-e ./generated/azure-special-headers-client-request-id/
Expand Down