Skip to content

Commit

Permalink
Multi-vector search (#31392)
Browse files Browse the repository at this point in the history
* Multi-vector search

* update changelog

* update async sample
  • Loading branch information
xiangyan99 committed Aug 1, 2023
1 parent 7b73d21 commit fbbce42
Show file tree
Hide file tree
Showing 48 changed files with 284 additions and 269 deletions.
2 changes: 2 additions & 0 deletions sdk/search/azure-search-documents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

- Added multi-vector search support. Now instead of passing in `vector`, `top_k` and `vector_fields`, search method accepts `vectors` which is a list of `Vector` object.

### Breaking Changes

### Bugs Fixed
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -662,8 +662,9 @@ def _serialize(self, target_obj, data_type=None, **kwargs):
_serialized.update(_new_attr) # type: ignore
_new_attr = _new_attr[k] # type: ignore
_serialized = _serialized[k]
except ValueError:
continue
except ValueError as err:
if isinstance(err, SerializationError):
raise

except (AttributeError, KeyError, TypeError) as err:
msg = "Attribute {} in object {} cannot be serialized.\n{}".format(attr_name, class_name, str(target_obj))
Expand Down Expand Up @@ -741,6 +742,8 @@ def query(self, name, data, data_type, **kwargs):
:param data: The data to be serialized.
:param str data_type: The type to be serialized from.
:keyword bool skip_quote: Whether to skip quote the serialized result.
Defaults to False.
:rtype: str
:raises: TypeError if serialization fails.
:raises: ValueError if data is None
Expand All @@ -749,10 +752,8 @@ def query(self, name, data, data_type, **kwargs):
# Treat the list aside, since we don't want to encode the div separator
if data_type.startswith("["):
internal_data_type = data_type[1:-1]
data = [self.serialize_data(d, internal_data_type, **kwargs) if d is not None else "" for d in data]
if not kwargs.get("skip_quote", False):
data = [quote(str(d), safe="") for d in data]
return str(self.serialize_iter(data, internal_data_type, **kwargs))
do_quote = not kwargs.get("skip_quote", False)
return str(self.serialize_iter(data, internal_data_type, do_quote=do_quote, **kwargs))

# Not a list, regular serialization
output = self.serialize_data(data, data_type, **kwargs)
Expand Down Expand Up @@ -891,6 +892,8 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs):
not be None or empty.
:param str div: If set, this str will be used to combine the elements
in the iterable into a combined string. Default is 'None'.
:keyword bool do_quote: Whether to quote the serialized result of each iterable element.
Defaults to False.
:rtype: list, str
"""
if isinstance(data, str):
Expand All @@ -903,9 +906,14 @@ def serialize_iter(self, data, iter_type, div=None, **kwargs):
for d in data:
try:
serialized.append(self.serialize_data(d, iter_type, **kwargs))
except ValueError:
except ValueError as err:
if isinstance(err, SerializationError):
raise
serialized.append(None)

if kwargs.get("do_quote", False):
serialized = ["" if s is None else quote(str(s), safe="") for s in serialized]

if div:
serialized = ["" if s is None else str(s) for s in serialized]
serialized = div.join(serialized)
Expand Down Expand Up @@ -950,7 +958,9 @@ def serialize_dict(self, attr, dict_type, **kwargs):
for key, value in attr.items():
try:
serialized[self.serialize_unicode(key)] = self.serialize_data(value, dict_type, **kwargs)
except ValueError:
except ValueError as err:
if isinstance(err, SerializationError):
raise
serialized[self.serialize_unicode(key)] = None

if "xml" in serialization_ctxt:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

from typing import List, cast

from azure.core.pipeline.transport import HttpRequest


Expand All @@ -14,15 +12,3 @@ def _convert_request(request, files=None):
if files:
request.set_formdata_body(files)
return request


def _format_url_section(template, **kwargs):
components = template.split("/")
while components:
try:
return template.format(**kwargs)
except KeyError as key:
# Need the cast, as for some reasons "split" is typed as list[str | Any]
formatted_components = cast(List[str], template.split("/"))
components = [c for c in formatted_components if "{}".format(key.args[0]) not in c]
template = "/".join(components)
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# pylint: disable=too-many-lines
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from io import IOBase
import sys
from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload

from azure.core.exceptions import (
Expand Down Expand Up @@ -36,11 +35,6 @@
build_suggest_post_request,
)

if sys.version_info >= (3, 9):
from collections.abc import MutableMapping
else:
from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports
JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object
T = TypeVar("T")
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, AsyncHttpResponse], T, Dict[str, Any]], Any]]

Expand Down Expand Up @@ -442,7 +436,7 @@ async def get(
selected_fields: Optional[List[str]] = None,
request_options: Optional[_models.RequestOptions] = None,
**kwargs: Any
) -> JSON:
) -> Dict[str, Any]:
"""Retrieves a document from the index.
.. seealso::
Expand All @@ -456,8 +450,8 @@ async def get(
:param request_options: Parameter group. Default value is None.
:type request_options: ~search_index_client.models.RequestOptions
:keyword callable cls: A custom type or function that will be passed the direct response
:return: JSON or the result of cls(response)
:rtype: JSON
:return: dict mapping str to any or the result of cls(response)
:rtype: dict[str, any]
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map = {
Expand All @@ -472,7 +466,7 @@ async def get(
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version))
cls: ClsType[JSON] = kwargs.pop("cls", None)
cls: ClsType[Dict[str, Any]] = kwargs.pop("cls", None)

_x_ms_client_request_id = None
if request_options is not None:
Expand Down Expand Up @@ -506,7 +500,7 @@ async def get(
error = self._deserialize.failsafe_deserialize(_models.SearchError, pipeline_response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize("object", pipeline_response)
deserialized = self._deserialize("{object}", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {})
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# coding=utf-8
# pylint: disable=too-many-lines
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down Expand Up @@ -1233,6 +1233,8 @@ class SearchRequest(_serialization.Model): # pylint: disable=too-many-instance-
:vartype semantic_fields: str
:ivar vector: The query parameters for vector and hybrid search queries.
:vartype vector: ~search_index_client.models.Vector
:ivar vectors: The query parameters for multi-vector search queries.
:vartype vectors: list[~search_index_client.models.Vector]
"""

_validation = {
Expand Down Expand Up @@ -1269,6 +1271,7 @@ class SearchRequest(_serialization.Model): # pylint: disable=too-many-instance-
"captions": {"key": "captions", "type": "str"},
"semantic_fields": {"key": "semanticFields", "type": "str"},
"vector": {"key": "vector", "type": "Vector"},
"vectors": {"key": "vectors", "type": "[Vector]"},
}

def __init__( # pylint: disable=too-many-locals
Expand Down Expand Up @@ -1303,6 +1306,7 @@ def __init__( # pylint: disable=too-many-locals
captions: Optional[Union[str, "_models.QueryCaptionType"]] = None,
semantic_fields: Optional[str] = None,
vector: Optional["_models.Vector"] = None,
vectors: Optional[List["_models.Vector"]] = None,
**kwargs: Any
) -> None:
"""
Expand Down Expand Up @@ -1421,6 +1425,8 @@ def __init__( # pylint: disable=too-many-locals
:paramtype semantic_fields: str
:keyword vector: The query parameters for vector and hybrid search queries.
:paramtype vector: ~search_index_client.models.Vector
:keyword vectors: The query parameters for multi-vector search queries.
:paramtype vectors: list[~search_index_client.models.Vector]
"""
super().__init__(**kwargs)
self.include_total_result_count = include_total_result_count
Expand Down Expand Up @@ -1452,6 +1458,7 @@ def __init__( # pylint: disable=too-many-locals
self.captions = captions
self.semantic_fields = semantic_fields
self.vector = vector
self.vectors = vectors


class SearchResult(_serialization.Model):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# pylint: disable=too-many-lines
# coding=utf-8
# --------------------------------------------------------------------------
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.5, generator: @autorest/python@6.4.11)
# Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.9.7, generator: @autorest/python@6.7.1)
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------
from io import IOBase
import sys
from typing import Any, Callable, Dict, IO, List, Optional, TypeVar, Union, overload

from azure.core.exceptions import (
Expand All @@ -24,13 +23,8 @@

from .. import models as _models
from .._serialization import Serializer
from .._vendor import _convert_request, _format_url_section
from .._vendor import _convert_request

if sys.version_info >= (3, 9):
from collections.abc import MutableMapping
else:
from typing import MutableMapping # type: ignore # pylint: disable=ungrouped-imports
JSON = MutableMapping[str, Any] # pylint: disable=unsubscriptable-object
T = TypeVar("T")
ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]]

Expand Down Expand Up @@ -215,7 +209,7 @@ def build_get_request(
"key": _SERIALIZER.url("key", key, "str"),
}

_url: str = _format_url_section(_url, **path_format_arguments) # type: ignore
_url: str = _url.format(**path_format_arguments) # type: ignore

# Construct parameters
if selected_fields is not None:
Expand Down Expand Up @@ -808,7 +802,7 @@ def get(
selected_fields: Optional[List[str]] = None,
request_options: Optional[_models.RequestOptions] = None,
**kwargs: Any
) -> JSON:
) -> Dict[str, Any]:
"""Retrieves a document from the index.
.. seealso::
Expand All @@ -822,8 +816,8 @@ def get(
:param request_options: Parameter group. Default value is None.
:type request_options: ~search_index_client.models.RequestOptions
:keyword callable cls: A custom type or function that will be passed the direct response
:return: JSON or the result of cls(response)
:rtype: JSON
:return: dict mapping str to any or the result of cls(response)
:rtype: dict[str, any]
:raises ~azure.core.exceptions.HttpResponseError:
"""
error_map = {
Expand All @@ -838,7 +832,7 @@ def get(
_params = case_insensitive_dict(kwargs.pop("params", {}) or {})

api_version: str = kwargs.pop("api_version", _params.pop("api-version", self._config.api_version))
cls: ClsType[JSON] = kwargs.pop("cls", None)
cls: ClsType[Dict[str, Any]] = kwargs.pop("cls", None)

_x_ms_client_request_id = None
if request_options is not None:
Expand Down Expand Up @@ -872,7 +866,7 @@ def get(
error = self._deserialize.failsafe_deserialize(_models.SearchError, pipeline_response)
raise HttpResponseError(response=response, model=error)

deserialized = self._deserialize("object", pipeline_response)
deserialized = self._deserialize("{object}", pipeline_response)

if cls:
return cls(pipeline_response, deserialized, {})
Expand Down

0 comments on commit fbbce42

Please sign in to comment.