Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Cosmos] make response_hook thread safe for the async client #34019

Merged
merged 1 commit into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,8 +766,7 @@ def _replace_throughput(
max_throughput = throughput.auto_scale_max_throughput
increment_percent = throughput.auto_scale_increment_percent
if max_throughput is not None:
new_throughput_properties['content']['offerAutopilotSettings'][
'maxThroughput'] = max_throughput
new_throughput_properties['content']['offerAutopilotSettings']['maxThroughput'] = max_throughput
if increment_percent:
new_throughput_properties['content']['offerAutopilotSettings']['autoUpgradePolicy']['throughputPolicy']['incrementPercent'] = increment_percent # pylint: disable=line-too-long
if throughput.offer_throughput:
Expand Down
46 changes: 4 additions & 42 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ async def read(
:returns: Dict representing the retrieved container.
:rtype: Dict[str, Any]
"""
response_hook = kwargs.pop('response_hook', None)
if session_token is not None:
kwargs['session_token'] = session_token
if priority_level is not None:
Expand All @@ -173,8 +172,6 @@ async def read(
self._properties = await self.client_connection.ReadContainer(
collection_link, options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, self._properties)
return self._properties

@distributed_trace_async
Expand Down Expand Up @@ -220,7 +217,6 @@ async def create_item(
:returns: A dict representing the new item.
:rtype: Dict[str, Any]
"""
response_hook = kwargs.pop('response_hook', None)
if pre_trigger_include is not None:
kwargs['pre_trigger_include'] = pre_trigger_include
if post_trigger_include is not None:
Expand All @@ -243,8 +239,6 @@ async def create_item(
result = await self.client_connection.CreateItem(
database_or_container_link=self.container_link, document=body, options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace_async
Expand Down Expand Up @@ -291,7 +285,6 @@ async def read_item(
:caption: Get an item from the database and update one of its properties:
:name: update_item
"""
response_hook = kwargs.pop('response_hook', None)
doc_link = self._get_document_link(item)
if post_trigger_include is not None:
kwargs['post_trigger_include'] = post_trigger_include
Expand All @@ -309,10 +302,7 @@ async def read_item(
validate_cache_staleness_value(max_integrated_cache_staleness_in_ms)
request_options["maxIntegratedCacheStaleness"] = max_integrated_cache_staleness_in_ms

result = await self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result
return await self.client_connection.ReadItem(document_link=doc_link, options=request_options, **kwargs)

@distributed_trace
def read_all_items(
Expand Down Expand Up @@ -571,7 +561,6 @@ async def upsert_item(
:returns: A dict representing the upserted item.
:rtype: Dict[str, Any]
"""
response_hook = kwargs.pop('response_hook', None)
if pre_trigger_include is not None:
kwargs['pre_trigger_include'] = pre_trigger_include
if post_trigger_include is not None:
Expand All @@ -595,8 +584,6 @@ async def upsert_item(
options=request_options,
**kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace_async
Expand Down Expand Up @@ -639,7 +626,6 @@ async def replace_item(
:returns: A dict representing the item after replace went through.
:rtype: Dict[str, Any]
"""
response_hook = kwargs.pop('response_hook', None)
item_link = self._get_document_link(item)
if pre_trigger_include is not None:
kwargs['pre_trigger_include'] = pre_trigger_include
Expand All @@ -661,8 +647,6 @@ async def replace_item(
result = await self.client_connection.ReplaceItem(
document_link=item_link, new_document=body, options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace_async
Expand Down Expand Up @@ -708,7 +692,6 @@ async def patch_item(
given id does not exist.
:rtype: dict[str, Any]
"""
response_hook = kwargs.pop('response_hook', None)
if pre_trigger_include is not None:
kwargs['pre_trigger_include'] = pre_trigger_include
if post_trigger_include is not None:
Expand All @@ -728,11 +711,8 @@ async def patch_item(
request_options["filterPredicate"] = filter_predicate

item_link = self._get_document_link(item)
result = await self.client_connection.PatchItem(
return await self.client_connection.PatchItem(
document_link=item_link, operations=patch_operations, options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace_async
async def delete_item(
Expand Down Expand Up @@ -770,7 +750,6 @@ async def delete_item(
:raises ~azure.cosmos.exceptions.CosmosResourceNotFoundError: The item does not exist in the container.
:rtype: None
"""
response_hook = kwargs.pop('response_hook', None)
if pre_trigger_include is not None:
kwargs['pre_trigger_include'] = pre_trigger_include
if post_trigger_include is not None:
Expand All @@ -788,8 +767,6 @@ async def delete_item(

document_link = self._get_document_link(item)
await self.client_connection.DeleteItem(document_link=document_link, options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, None)

@distributed_trace_async
async def get_throughput(self, **kwargs: Any) -> ThroughputProperties:
Expand Down Expand Up @@ -842,7 +819,6 @@ async def replace_throughput(
:returns: ThroughputProperties for the container, updated with new throughput.
:rtype: ~azure.cosmos.offer.ThroughputProperties
"""
response_hook = kwargs.pop('response_hook', None)
properties = await self._get_properties()
link = properties["_self"]
query_spec = {
Expand All @@ -860,8 +836,6 @@ async def replace_throughput(
_replace_throughput(throughput=throughput, new_throughput_properties=new_offer)
data = await self.client_connection.ReplaceOffer(offer_link=throughput_properties[0]["_self"],
offer=throughput_properties[0], **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, data)

return ThroughputProperties(offer_throughput=data["content"]["offerThroughput"], properties=data)

Expand Down Expand Up @@ -955,13 +929,10 @@ async def get_conflict(
:rtype: Dict[str, Any]
"""
request_options = _build_options(kwargs)
response_hook = kwargs.pop('response_hook', None)
request_options["partitionKey"] = await self._set_partition_key(partition_key)
result = await self.client_connection.ReadConflict(
conflict_link=self._get_conflict_link(conflict), options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result

@distributed_trace_async
Expand All @@ -986,13 +957,11 @@ async def delete_conflict(
:rtype: None
"""
request_options = _build_options(kwargs)
response_hook = kwargs.pop('response_hook', None)
request_options["partitionKey"] = await self._set_partition_key(partition_key)

await self.client_connection.DeleteConflict(
conflict_link=self._get_conflict_link(conflict), options=request_options, **kwargs
)
if response_hook:
response_hook(self.client_connection.last_response_headers, None)

@distributed_trace_async
async def delete_all_items_by_partition_key(
Expand Down Expand Up @@ -1023,7 +992,6 @@ async def delete_all_items_by_partition_key(
:keyword Callable response_hook: A callable invoked with the response metadata.
:rtype: None
"""
response_hook = kwargs.pop('response_hook', None)
if pre_trigger_include is not None:
kwargs['pre_trigger_include'] = pre_trigger_include
if post_trigger_include is not None:
Expand All @@ -1040,8 +1008,6 @@ async def delete_all_items_by_partition_key(

await self.client_connection.DeleteAllItemsByPartitionKey(collection_link=self.container_link,
options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, None)

@distributed_trace_async
async def execute_item_batch(
Expand Down Expand Up @@ -1074,7 +1040,6 @@ async def execute_item_batch(
:raises ~azure.cosmos.exceptions.CosmosBatchOperationError: A transactional batch operation failed in the batch.
:rtype: List[Dict[str, Any]]
"""
response_hook = kwargs.pop('response_hook', None)
if pre_trigger_include is not None:
kwargs['pre_trigger_include'] = pre_trigger_include
if post_trigger_include is not None:
Expand All @@ -1089,8 +1054,5 @@ async def execute_item_batch(
request_options["partitionKey"] = await self._set_partition_key(partition_key)
request_options["disableAutomaticIdGeneration"] = True

result = await self.client_connection.Batch(
return await self.client_connection.Batch(
collection_link=self.container_link, batch_operations=batch_operations, options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return result
3 changes: 0 additions & 3 deletions sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,6 @@ async def create_database(
:caption: Create a database in the Cosmos DB account:
:name: create_database
"""
response_hook = kwargs.pop('response_hook', None)
if session_token is not None:
kwargs["session_token"] = session_token
if initial_headers is not None:
Expand All @@ -284,8 +283,6 @@ async def create_database(
_set_throughput_options(offer=offer_throughput, request_options=request_options)

result = await self.client_connection.CreateDatabase(database=dict(id=id), options=request_options, **kwargs)
if response_hook:
response_hook(self.client_connection.last_response_headers, result)
return DatabaseProxy(self.client_connection, id=result["id"], properties=result)

@distributed_trace_async
Expand Down
Loading
Loading