diff --git a/deeplake/core/vectorstore/dataset_handlers/client_side_dataset_handler.py b/deeplake/core/vectorstore/dataset_handlers/client_side_dataset_handler.py
index f587d8b5fc..6445db99b1 100644
--- a/deeplake/core/vectorstore/dataset_handlers/client_side_dataset_handler.py
+++ b/deeplake/core/vectorstore/dataset_handlers/client_side_dataset_handler.py
@@ -168,6 +168,7 @@ def search(
         return_tensors: List[str],
         return_view: bool,
         deep_memory: bool,
+        return_tql: bool,
     ) -> Union[Dict, Dataset]:
         feature_report_path(
             path=self.bugout_reporting_path,
@@ -243,7 +244,7 @@ def search(
             embedding_tensor=embedding_tensor,
             return_tensors=return_tensors,
             return_view=return_view,
-            deep_memory=deep_memory,
+            return_tql=return_tql,
             token=self.token,
             org_id=self.org_id,
         )
diff --git a/deeplake/core/vectorstore/deep_memory/deep_memory.py b/deeplake/core/vectorstore/deep_memory/deep_memory.py
index 2656b8f3d3..4d9370a727 100644
--- a/deeplake/core/vectorstore/deep_memory/deep_memory.py
+++ b/deeplake/core/vectorstore/deep_memory/deep_memory.py
@@ -1,4 +1,5 @@
 import logging
+import pathlib
 import uuid
 from collections import defaultdict
 from pydantic import BaseModel, ValidationError
@@ -8,13 +9,12 @@
 import numpy as np

 import deeplake
-from deeplake.enterprise.dataloader import indra_available
 from deeplake.util.exceptions import (
     DeepMemoryWaitingListError,
     IncorrectRelevanceTypeError,
     IncorrectQueriesTypeError,
 )
-from deeplake.util.remove_cache import get_base_storage
+from deeplake.util.path import convert_pathlib_to_string_if_needed
 from deeplake.constants import (
     DEFAULT_QUERIES_VECTORSTORE_TENSORS,
     DEFAULT_MEMORY_CACHE_SIZE,
@@ -30,7 +31,15 @@
     feature_report_path,
 )
 from deeplake.util.path import get_path_type
-from deeplake.util.version_control import load_meta
+
+
+def access_control(func):
+    def wrapper(self, *args, **kwargs):
+        if self.client is None:
+            raise DeepMemoryWaitingListError()
+        return func(self, *args, **kwargs)
+
+    return wrapper


 def use_deep_memory(func):
@@ -46,15 +55,6 @@ def wrapper(self, *args, **kwargs):
     return wrapper


-def access_control(func):
-    def wrapper(self, *args, **kwargs):
-        if self.client is None:
-            raise DeepMemoryWaitingListError()
-        return func(self, *args, **kwargs)
-
-    return wrapper
-
-
 class Relevance(BaseModel):
     data: List[List[Tuple[str, int]]]

@@ -78,7 +78,8 @@ def validate_relevance_and_queries(relevance, queries):
 class DeepMemory:
     def __init__(
         self,
-        dataset_or_path: Union[Dataset, str],
+        dataset: Dataset,
+        path: Union[str, pathlib.Path],
         logger: logging.Logger,
         embedding_function: Optional[Any] = None,
         token: Optional[str] = None,
@@ -87,7 +88,8 @@ def __init__(
         """Based Deep Memory class to train and evaluate models on DeepMemory managed service.

         Args:
-            dataset_or_path (Union[Dataset, str]): deeplake dataset object or path.
+            dataset (Dataset): deeplake dataset object.
+            path (Union[str, pathlib.Path]): Path to the dataset.
             logger (logging.Logger): Logger object.
             embedding_function (Optional[Any], optional): Embedding funtion class used to convert queries/documents to embeddings. Defaults to None.
             token (Optional[str], optional): API token for the DeepMemory managed service. Defaults to None.
@@ -95,16 +97,9 @@ def __init__

         Raises:
             ImportError: if indra is not installed
-            ValueError: if incorrect type is specified for `dataset_or_path`
         """
-        if isinstance(dataset_or_path, Dataset):
-            self.path = dataset_or_path.path
-        elif isinstance(dataset_or_path, str):
-            self.path = dataset_or_path
-        else:
-            raise ValueError(
-                "dataset_or_path should be a Dataset object or a string path"
-            )
+        self.dataset = dataset
+        self.path = convert_pathlib_to_string_if_needed(path)

         feature_report_path(
             path=self.path,
@@ -143,7 +138,8 @@ def train(
             relevance (List[List[Tuple[str, int]]]): List of relevant documents for each query with their respective relevance score.
                 The outer list corresponds to the queries and the inner list corresponds to the doc_id, relevence_score pair for each query.
                 doc_id is the document id in the corpus dataset. It is stored in the `id` tensor of the corpus dataset.
-                relevence_score is the relevance score of the document for the query. The range is between 0 and 1, where 0 stands for not relevant and 1 stands for relevant.
+                relevence_score is the relevance score of the document for the query. The value is either 0 or 1, where 0 stands for not relevant (unknown relevance)
+                and 1 stands for relevant. Currently, only values of 1 contribute to the training, and there is no reason to provide examples with relevance of 0.
             embedding_function (Optional[Callable[[str], np.ndarray]], optional): Embedding funtion used to convert queries to embeddings. Defaults to None.
             token (str, optional): API token for the DeepMemory managed service. Defaults to None.
@@ -178,7 +174,7 @@ def train(
         )

         if embedding_function is None and self.embedding_function is not None:
-            embedding_function = self.embedding_function.embed_documents
+            embedding_function = self.embedding_function

         runtime = None
         if get_path_type(corpus_path) == "hub":
@@ -484,10 +480,8 @@ def evaluate(
         if embedding is not None:
             query_embs = embedding
         else:
-            if self.embedding_function is not None:
-                embedding_function = (
-                    embedding_function or self.embedding_function.embed_documents
-                )
+            if self.embedding_function is not None and embedding_function is None:
+                embedding_function = self.embedding_function

             if embedding_function is None:
                 raise ValueError(
@@ -554,6 +548,46 @@ def evaluate(
         self.queries_dataset.commit()
         return recalls

+    @access_control
+    def get_model(self):
+        """Get the name of the model currently being used by the DeepMemory managed service."""
+        return self.dataset.embedding.info["deepmemory"]["model.npy"]["job_id"]
+
+    @access_control
+    def set_model(self, model_name: str):
+        """Set model.npy to use `model_name` instead of the default model.
+        Args:
+            model_name (str): name of the model to use
+        """
+
+        if "npy" not in model_name:
+            model_name += ".npy"
+
+        # verify model_name
+        self._verify_model_name(model_name)
+
+        # set model.npy to use `model_name` instead of default model
+        self._set_model_npy(model_name)
+
+    def _verify_model_name(self, model_name: str):
+        if model_name not in self.dataset.embedding.info["deepmemory"]:
+            raise ValueError(
+                "Invalid model name. 
Please choose from the following models: " + + ", ".join(self.dataset.embedding.info["deepmemory"].keys()) + ) + + def _set_model_npy(self, model_name: str): + # get new model.npy + new_model_npy = self.dataset.embedding.info["deepmemory"][model_name] + + # get old deepmemory dictionary and update it: + old_deepmemory = self.dataset.embedding.info["deepmemory"] + new_deepmemory = old_deepmemory.copy() + new_deepmemory.update({"model.npy": new_model_npy}) + + # assign new deepmemory dictionary to the dataset: + self.dataset.embedding.info["deepmemory"] = new_deepmemory + def _get_dm_client(self): path = self.path path_type = get_path_type(path) diff --git a/deeplake/core/vectorstore/deep_memory/test_deepmemory.py b/deeplake/core/vectorstore/deep_memory/test_deepmemory.py index 0a458974fa..cda3232dce 100644 --- a/deeplake/core/vectorstore/deep_memory/test_deepmemory.py +++ b/deeplake/core/vectorstore/deep_memory/test_deepmemory.py @@ -3,6 +3,7 @@ import pytest import sys from time import sleep +from unittest.mock import MagicMock import deeplake from deeplake import VectorStore @@ -40,9 +41,9 @@ def test_deepmemory_init(hub_cloud_path, hub_cloud_dev_token): assert db.deep_memory is not None -def embedding_fn(texts): +def embedding_fn(texts, embedding_dim=1536): return [ - np.random.uniform(low=-10, high=10, size=(1536)).astype(np.float32) + np.random.uniform(low=-10, high=10, size=(embedding_dim)).astype(np.float32) for _ in range(len(texts)) ] @@ -432,7 +433,7 @@ def test_deepmemory_evaluate_with_embedding_func_in_init( path=corpus, runtime={"tensor_db": True}, token=hub_cloud_dev_token, - embedding_function=DummyEmbedder, + embedding_function=embedding_fn, ) recall = db.deep_memory.evaluate( queries=queries, @@ -584,7 +585,10 @@ def test_deepmemory_search_on_local_datasets( @requires_libdeeplake def test_unsupported_deepmemory_users(local_ds): dm = DeepMemory( - dataset_or_path=local_ds, logger=logger, embedding_function=DummyEmbedder + path=local_ds, + dataset=None, + logger=logger, + embedding_function=DummyEmbedder, ) with pytest.raises(DeepMemoryWaitingListError): dm.train( @@ -660,3 +664,121 @@ def test_not_supported_training_args(corpus_query_relevances_copy, hub_cloud_dev queries=queries, relevance="relevances", ) + + +def test_deepmemory_v2_set_model_should_set_model_for_all_subsequent_loads( + local_dmv2_dataset, + hub_cloud_dev_token, +): + # Setiing model should set model for all subsequent loads + db = VectorStore(path=local_dmv2_dataset, token=hub_cloud_dev_token) + assert db.deep_memory.get_model() == "655f86e8ab93e7fc5067a3ac_2" + + # ensure after setting model, get model returns specified model + db.deep_memory.set_model("655f86e8ab93e7fc5067a3ac_1") + + assert ( + db.dataset.embedding.info["deepmemory"]["model.npy"]["job_id"] + == "655f86e8ab93e7fc5067a3ac_1" + ) + assert db.deep_memory.get_model() == "655f86e8ab93e7fc5067a3ac_1" + + # ensure after setting model, reloading the dataset returns the same model + db = VectorStore(path=local_dmv2_dataset, token=hub_cloud_dev_token) + assert db.deep_memory.get_model() == "655f86e8ab93e7fc5067a3ac_1" + + +@pytest.mark.slow +@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows") +def test_deepmemory_search_should_contain_correct_answer( + corpus_query_relevances_copy, + testing_relevance_query_deepmemory, + hub_cloud_dev_token, +): + corpus, _, _, _ = corpus_query_relevances_copy + relevance, query_embedding = testing_relevance_query_deepmemory + + db = VectorStore( + path=corpus, + 
token=hub_cloud_dev_token, + ) + + output = db.search( + embedding=query_embedding, deep_memory=True, return_tensors=["id"] + ) + assert len(output["id"]) == 4 + assert relevance in output["id"] + + +@pytest.mark.slow +@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows") +def test_deeplake_search_should_not_contain_correct_answer( + corpus_query_relevances_copy, + testing_relevance_query_deepmemory, + hub_cloud_dev_token, +): + corpus, _, _, _ = corpus_query_relevances_copy + relevance, query_embedding = testing_relevance_query_deepmemory + + db = VectorStore( + path=corpus, + token=hub_cloud_dev_token, + ) + output = db.search(embedding=query_embedding) + assert len(output["id"]) == 4 + assert relevance not in output["id"] + + +@pytest.mark.slow +@pytest.mark.flaky(reruns=3) +@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows") +def test_deepmemory_train_with_embedding_function_specified_in_constructor_should_not_throw_any_exception( + deepmemory_small_dataset_copy, + hub_cloud_dev_token, +): + corpus, queries, relevances, _ = deepmemory_small_dataset_copy + + db = VectorStore( + path=corpus, + runtime={"tensor_db": True}, + token=hub_cloud_dev_token, + embedding_function=embedding_fn, + ) + + job_id = db.deep_memory.train( + queries=queries, + relevance=relevances, + ) + + +@pytest.mark.slow +@pytest.mark.flaky(reruns=3) +@pytest.mark.skipif(sys.platform == "win32", reason="Does not run on Windows") +def test_deepmemory_evaluate_with_embedding_function_specified_in_constructor_should_not_throw_any_exception( + corpus_query_pair_path, + hub_cloud_dev_token, +): + corpus, queries = corpus_query_pair_path + + db = VectorStore( + path=corpus, + runtime={"tensor_db": True}, + token=hub_cloud_dev_token, + embedding_function=embedding_fn, + ) + + queries_vs = VectorStore( + path=queries, + runtime={"tensor_db": True}, + token=hub_cloud_dev_token, + embedding_function=embedding_fn, + ) + + queries = queries_vs.dataset[:10].text.data()["value"] + relevance = queries_vs.dataset[:10].metadata.data()["value"] + relevance = [rel["relevance"] for rel in relevance] + + recall = db.deep_memory.evaluate( + queries=queries, + relevance=relevance, + ) diff --git a/deeplake/core/vectorstore/deeplake_vectorstore.py b/deeplake/core/vectorstore/deeplake_vectorstore.py index 69e643e7d3..61cd5a7ecc 100644 --- a/deeplake/core/vectorstore/deeplake_vectorstore.py +++ b/deeplake/core/vectorstore/deeplake_vectorstore.py @@ -8,6 +8,8 @@ from deeplake.core.dataset import Dataset from deeplake.core.vectorstore.dataset_handlers import get_dataset_handler from deeplake.core.vectorstore.deep_memory import DeepMemory +from deeplake.core.vectorstore.dataset_handlers import get_dataset_handler +from deeplake.core.vectorstore.deep_memory import DeepMemory from deeplake.constants import ( DEFAULT_VECTORSTORE_TENSORS, MAX_BYTES_PER_MINUTE, @@ -131,7 +133,8 @@ def __init__( ) self.deep_memory = DeepMemory( - dataset_or_path=self.dataset_handler.path, + dataset=self.dataset_handler.dataset, + path=self.dataset_handler.path, token=self.dataset_handler.token, logger=logger, embedding_function=embedding_function, @@ -240,6 +243,7 @@ def search( return_tensors: Optional[List[str]] = None, return_view: bool = False, deep_memory: bool = False, + return_tql: bool = False, ) -> Union[Dict, Dataset]: """VectorStore search method that combines embedding search, metadata search, and custom TQL search. 
@@ -290,6 +294,7 @@ def search( return_view (bool): Return a Deep Lake dataset view that satisfied the search parameters, instead of a dictionary with data. Defaults to False. If ``True`` return_tensors is set to "*" beucase data is lazy-loaded and there is no cost to including all tensors in the view. deep_memory (bool): Whether to use the Deep Memory model for improving search results. Defaults to False if deep_memory is not specified in the Vector Store initialization. If True, the distance metric is set to "deepmemory_distance", which represents the metric with which the model was trained. The search is performed using the Deep Memory model. If False, the distance metric is set to "COS" or whatever distance metric user specifies. + return_tql (bool): Whether to return the TQL query string used for the search. Defaults to False. .. # noqa: DAR101 @@ -317,6 +322,7 @@ def search( embedding_tensor=embedding_tensor, return_tensors=return_tensors, return_view=return_view, + return_tql=return_tql, deep_memory=deep_memory, ) diff --git a/deeplake/core/vectorstore/test_deeplake_vectorstore.py b/deeplake/core/vectorstore/test_deeplake_vectorstore.py index 7194a62906..5c25721af2 100644 --- a/deeplake/core/vectorstore/test_deeplake_vectorstore.py +++ b/deeplake/core/vectorstore/test_deeplake_vectorstore.py @@ -1299,10 +1299,14 @@ def create_and_populate_vs( overwrite=True, verbose=False, exec_option="compute_engine", - index_params={"threshold": 10}, + index_params={"threshold": -1}, number_of_data=NUMBER_OF_DATA, + runtime=None, ): - # TODO: cache the vectostore object and reuse it in other tests (maybe with deepcopy) + # if runtime specified and tensor_db is enabled, then set exec_option to None + if runtime and runtime.get("tensor_db", False): + exec_option = None + vector_store = DeepLakeVectorStore( path=path, overwrite=overwrite, @@ -1310,6 +1314,7 @@ def create_and_populate_vs( exec_option=exec_option, index_params=index_params, token=token, + runtime=runtime, ) utils.create_data(number_of_data=number_of_data, embedding_dim=EMBEDDING_DIM) @@ -1357,7 +1362,6 @@ def test_update_embedding_row_ids_and_filter_specified_should_throw_exception( ) embedding_fn = get_embedding_function() - # calling update_embedding with both ids and filter being specified with pytest.raises(ValueError): vector_store.update_embedding( row_ids=vector_store_row_ids, @@ -1381,6 +1385,7 @@ def test_update_embedding_query_and_filter_specified_should_throw_exception( embedding_fn = get_embedding_function() # calling update_embedding with both query and filter being specified + with pytest.raises(ValueError): vector_store.update_embedding( filter=vector_store_filters, @@ -2851,6 +2856,46 @@ def test_vs_commit(local_path): assert len(db) == NUMBER_OF_DATA +def test_vs_init_when_both_dataset_and_path_is_specified(local_path): + with pytest.raises(ValueError): + VectorStore( + path=local_path, + dataset=deeplake.empty(local_path, overwrite=True), + ) + + +def test_vs_init_when_both_dataset_and_path_are_not_specified(): + with pytest.raises(ValueError): + VectorStore() + + +def test_vs_init_with_emptyt_token(local_path): + with patch("deeplake.client.config.DEEPLAKE_AUTH_TOKEN", ""): + db = VectorStore( + path=local_path, + ) + + assert db.dataset_handler.username == "public" + + +@pytest.fixture +def mock_search_managed(mocker): + # Replace SearchManaged with a mock + mock_class = mocker.patch( + "deeplake.core.vectorstore.vector_search.indra.search_algorithm.SearchManaged" + ) + return mock_class + + +@pytest.fixture +def 
mock_search_indra(mocker): + # Replace SearchIndra with a mock + mock_class = mocker.patch( + "deeplake.core.vectorstore.vector_search.indra.search_algorithm.SearchIndra" + ) + return mock_class + + def test_vs_init_when_both_dataset_and_path_is_specified_should_throw_exception( local_path, ): @@ -2880,3 +2925,74 @@ def test_vs_init_with_emptyt_token_should_not_throw_exception(local_path): ) assert db.dataset_handler.username == "public" + + +@pytest.mark.slow +def test_db_search_with_managed_db_should_instantiate_SearchManaged_class( + mock_search_managed, hub_cloud_path, hub_cloud_dev_token +): + # using interaction test to ensure that the search managed class is executed + db = create_and_populate_vs( + hub_cloud_path, + runtime={"tensor_db": True}, + token=hub_cloud_dev_token, + ) + + # Perform the search + db.search(embedding=query_embedding) + + # Assert that SearchManaged was instantiated + mock_search_managed.assert_called() + + +@pytest.mark.slow +@requires_libdeeplake +def test_db_search_should_instantiate_SearchIndra_class( + mock_search_indra, hub_cloud_path, hub_cloud_dev_token +): + # using interaction test to ensure that the search indra class is executed + db = create_and_populate_vs( + hub_cloud_path, + token=hub_cloud_dev_token, + ) + + # Perform the search + db.search(embedding=query_embedding) + + # Assert that SearchIndra was instantiated + mock_search_indra.assert_called() + + +def returning_tql_for_exec_option_python_should_throw_exception(local_path): + db = VectorStore( + path=local_path, + ) + db.add(text=texts, embedding=embeddings, id=ids, metadata=metadatas) + + with pytest.raises(NotImplementedError): + db.search(embedding=query_embedding, return_tql=True) + + +def test_returning_tql_for_exec_option_compute_engine_should_return_correct_tql( + local_path, + hub_cloud_dev_token, +): + db = VectorStore( + path=local_path, + token=hub_cloud_dev_token, + ) + + texts, embeddings, ids, metadatas, _ = utils.create_data( + number_of_data=10, embedding_dim=3 + ) + + db.add(text=texts, embedding=embeddings, id=ids, metadata=metadatas) + + query_embedding = np.zeros(3, dtype=np.float32) + output = db.search(embedding=query_embedding, return_tql=True) + + assert output["tql"] == ( + "select text, metadata, id, score from " + "(select *, COSINE_SIMILARITY(embedding, ARRAY[0.0, 0.0, 0.0]) as score " + "order by COSINE_SIMILARITY(embedding, ARRAY[0.0, 0.0, 0.0]) DESC limit 4)" + ) diff --git a/deeplake/core/vectorstore/vector_search/dataset/dataset.py b/deeplake/core/vectorstore/vector_search/dataset/dataset.py index 2450718767..ffa3f88b0d 100644 --- a/deeplake/core/vectorstore/vector_search/dataset/dataset.py +++ b/deeplake/core/vectorstore/vector_search/dataset/dataset.py @@ -567,6 +567,7 @@ def convert_id_to_row_id(ids, dataset, search_fn, query, exec_option, filter): return_view=True, k=int(1e9), deep_memory=False, + return_tql=False, ) else: diff --git a/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py b/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py index 71a8a1f353..35b77ffbae 100644 --- a/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py +++ b/deeplake/core/vectorstore/vector_search/indra/search_algorithm.py @@ -1,4 +1,5 @@ import numpy as np +from abc import ABC, abstractmethod from typing import Union, Dict, List, Optional from deeplake.core.vectorstore.vector_search.indra import query @@ -9,6 +10,157 @@ from deeplake.enterprise.util import raise_indra_installation_error +class SearchBasic(ABC): + def __init__( + 
self, + deeplake_dataset: DeepLakeDataset, + org_id: Optional[str] = None, + token: Optional[str] = None, + runtime: Optional[Dict] = None, + deep_memory: bool = False, + ): + """Base class for all search algorithms. + Args: + deeplake_dataset (DeepLakeDataset): DeepLake dataset object. + org_id (Optional[str], optional): Organization ID, is needed only for local datasets. Defaults to None. + token (Optional[str], optional): Token used for authentication. Defaults to None. + runtime (Optional[Dict], optional): Whether to run query on managed_db or indra. Defaults to None. + deep_memory (bool): Use DeepMemory for the search. Defaults to False. + """ + self.deeplake_dataset = deeplake_dataset + self.org_id = org_id + self.token = token + self.runtime = runtime + self.deep_memory = deep_memory + + def run( + self, + tql_string: str, + return_view: bool, + return_tql: bool, + distance_metric: str, + k: int, + query_embedding: np.ndarray, + embedding_tensor: str, + tql_filter: str, + return_tensors: List[str], + ): + tql_query = self._create_tql_string( + tql_string, + distance_metric, + k, + query_embedding, + embedding_tensor, + tql_filter, + return_tensors, + ) + view = self._get_view( + tql_query, + runtime=self.runtime, + ) + + if return_view: + return view + + return_data = self._collect_return_data(view) + + if return_tql: + return {"data": return_data, "tql": tql_query} + return return_data + + @abstractmethod + def _collect_return_data( + self, + view: DeepLakeDataset, + ): + pass + + @staticmethod + def _create_tql_string( + tql_string: str, + distance_metric: str, + k: int, + query_embedding: np.ndarray, + embedding_tensor: str, + tql_filter: str, + return_tensors: List[str], + ): + """Creates TQL query string for the vector search.""" + if tql_string: + return tql_string + else: + return query.parse_query( + distance_metric, + k, + query_embedding, + embedding_tensor, + tql_filter, + return_tensors, + ) + + @abstractmethod + def _get_view(self, tql_query: str, runtime: Optional[Dict] = None): + pass + + +class SearchIndra(SearchBasic): + def _get_view(self, tql_query, runtime: Optional[Dict] = None): + indra_dataset = self._get_indra_dataset() + indra_view = indra_dataset.query(tql_query) + view = DeepLakeQueryDataset( + deeplake_ds=self.deeplake_dataset, indra_ds=indra_view + ) + view._tql_query = tql_query + return view + + def _get_indra_dataset(self): + try: + from indra import api # type: ignore + + INDRA_INSTALLED = True + except ImportError: + INDRA_INSTALLED = False + pass + + if not INDRA_INSTALLED: + raise raise_indra_installation_error(indra_import_error=None) + + if self.deeplake_dataset.libdeeplake_dataset is not None: + indra_dataset = self.deeplake_dataset.libdeeplake_dataset + else: + if self.org_id is not None: + self.deeplake_dataset.org_id = self.org_id + if self.token is not None: + self.deeplake_dataset.set_token(self.token) + + indra_dataset = dataset_to_libdeeplake(self.deeplake_dataset) + return indra_dataset + + def _collect_return_data( + self, + view: DeepLakeDataset, + ): + return_data = {} + for tensor in view.tensors: + return_data[tensor] = utils.parse_tensor_return(view[tensor]) + return return_data + + +class SearchManaged(SearchBasic): + def _get_view(self, tql_query, runtime: Optional[Dict] = None): + view, data = self.deeplake_dataset.query( + tql_query, runtime=runtime, return_data=True + ) + self.data = data + return view + + def _collect_return_data( + self, + view: DeepLakeDataset, + ): + return self.data + + def search( query_embedding: 
np.ndarray, distance_metric: str, @@ -20,9 +172,9 @@ def search( runtime: dict, return_tensors: List[str], return_view: bool = False, - deep_memory: bool = False, token: Optional[str] = None, org_id: Optional[str] = None, + return_tql: bool = False, ) -> Union[Dict, DeepLakeDataset]: """Generalized search algorithm that uses indra. It combines vector search and other TQL queries. @@ -37,9 +189,9 @@ def search( runtime (dict): Runtime parameters for the query. return_tensors (List[str]): List of tensors to return data for. return_view (bool): Return a Deep Lake dataset view that satisfied the search parameters, instead of a dictinary with data. Defaults to False. - deep_memory (bool): Use DeepMemory for the search. Defaults to False. token (Optional[str], optional): Token used for authentication. Defaults to None. org_id (Optional[str], optional): Organization ID, is needed only for local datasets. Defaults to None. + return_tql (bool): Return TQL query used for the search. Defaults to False. Raises: ValueError: If both tql_string and tql_filter are specified. @@ -48,76 +200,20 @@ def search( Returns: Union[Dict, DeepLakeDataset]: Dictionary where keys are tensor names and values are the results of the search, or a Deep Lake dataset view. """ - try: - from indra import api # type: ignore - - INDRA_INSTALLED = True - except ImportError: - INDRA_INSTALLED = False - pass - - if tql_string: - tql_query = tql_string + searcher: SearchBasic + if runtime and runtime.get("db_engine", False): + searcher = SearchManaged(deeplake_dataset, org_id, token, runtime=runtime) else: - tql_query = query.parse_query( - distance_metric, - k, - query_embedding, - embedding_tensor, - tql_filter, - return_tensors, - ) - - if deep_memory: - if not INDRA_INSTALLED: - raise raise_indra_installation_error(indra_import_error=None) - - if deeplake_dataset.libdeeplake_dataset is not None: - indra_dataset = deeplake_dataset.libdeeplake_dataset - else: - if org_id is not None: - deeplake_dataset.org_id = org_id - if token is not None: - deeplake_dataset.set_token(token) - - indra_dataset = dataset_to_libdeeplake(deeplake_dataset) - api.tql.prepare_deepmemory_metrics(indra_dataset) - - indra_view = indra_dataset.query(tql_query) - - view = DeepLakeQueryDataset(deeplake_ds=deeplake_dataset, indra_ds=indra_view) - view._tql_query = tql_query - - if return_view: - return view - - return_data = {} - for tensor in view.tensors: - return_data[tensor] = utils.parse_tensor_return(view[tensor]) - elif runtime and runtime.get("db_engine", False): - view, data = deeplake_dataset.query( - tql_query, runtime=runtime, return_data=True - ) - if return_view: - return view - - return_data = data - else: - if not INDRA_INSTALLED: - raise raise_indra_installation_error( - indra_import_error=None - ) # pragma: no cover - - view = deeplake_dataset.query( - tql_query, - runtime=runtime, - ) - - if return_view: - return view - - return_data = {} - for tensor in view.tensors: - return_data[tensor] = utils.parse_tensor_return(view[tensor]) - - return return_data + searcher = SearchIndra(deeplake_dataset, org_id, token) + + return searcher.run( + tql_string=tql_string, + return_view=return_view, + return_tql=return_tql, + distance_metric=distance_metric, + k=k, + query_embedding=query_embedding, + embedding_tensor=embedding_tensor, + tql_filter=tql_filter, + return_tensors=return_tensors, + ) diff --git a/deeplake/core/vectorstore/vector_search/indra/vector_search.py b/deeplake/core/vectorstore/vector_search/indra/vector_search.py index 
b0e80cb019..96992567c6 100644 --- a/deeplake/core/vectorstore/vector_search/indra/vector_search.py +++ b/deeplake/core/vectorstore/vector_search/indra/vector_search.py @@ -18,9 +18,9 @@ def vector_search( k, return_tensors, return_view, - deep_memory, token, org_id, + return_tql, ) -> Union[Dict, DeepLakeDataset]: try: from indra import api # type: ignore @@ -55,7 +55,7 @@ def vector_search( runtime=runtime, return_tensors=return_tensors, return_view=return_view, - deep_memory=deep_memory, token=token, org_id=org_id, + return_tql=return_tql, ) diff --git a/deeplake/core/vectorstore/vector_search/python/test_vector_search.py b/deeplake/core/vectorstore/vector_search/python/test_vector_search.py index da18ec955f..848ca6a9a5 100644 --- a/deeplake/core/vectorstore/vector_search/python/test_vector_search.py +++ b/deeplake/core/vectorstore/vector_search/python/test_vector_search.py @@ -26,9 +26,9 @@ def test_vector_search(): k=10, return_tensors=[], return_view=False, - deep_memory=False, token=None, org_id=None, + return_tql=False, ) assert len(data["score"]) == 10 @@ -46,9 +46,9 @@ def test_vector_search(): k=10, return_tensors=[], return_view=False, - deep_memory=False, token=None, org_id=None, + return_tql=False, ) data = vector_search.vector_search( @@ -63,9 +63,9 @@ def test_vector_search(): k=10, return_tensors=[], return_view=True, - deep_memory=False, token=None, org_id=None, + return_tql=False, ) assert len(data) == 10 @@ -84,7 +84,7 @@ def test_vector_search(): k=10, return_tensors=[], return_view=True, - deep_memory=False, token=None, org_id=None, + return_tql=False, ) diff --git a/deeplake/core/vectorstore/vector_search/python/vector_search.py b/deeplake/core/vectorstore/vector_search/python/vector_search.py index e632b5e5a1..b1ecf1885b 100644 --- a/deeplake/core/vectorstore/vector_search/python/vector_search.py +++ b/deeplake/core/vectorstore/vector_search/python/vector_search.py @@ -18,15 +18,20 @@ def vector_search( k, return_tensors, return_view, - deep_memory, token, org_id, + return_tql, ) -> Union[Dict, DeepLakeDataset]: if query is not None: raise NotImplementedError( f"User-specified TQL queries are not supported for exec_option={exec_option} " ) + if return_tql: + raise NotImplementedError( + f"return_tql is not supported for exec_option={exec_option}" + ) + view = filter_utils.attribute_based_filtering_python(dataset, filter) return_data = {} diff --git a/deeplake/core/vectorstore/vector_search/utils.py b/deeplake/core/vectorstore/vector_search/utils.py index 7447cb87d2..6927a33571 100644 --- a/deeplake/core/vectorstore/vector_search/utils.py +++ b/deeplake/core/vectorstore/vector_search/utils.py @@ -190,11 +190,13 @@ def generate_json(value, key): return {key: value} -def create_data(number_of_data, embedding_dim=100, metadata_key="abc"): +def create_data( + number_of_data, embedding_dim=100, metadata_key="abc", string_length=1000 +): embeddings = np.random.uniform( low=-10, high=10, size=(number_of_data, embedding_dim) ).astype(np.float32) - texts = [generate_random_string(1000) for i in range(number_of_data)] + texts = [generate_random_string(string_length) for i in range(number_of_data)] ids = [f"{i}" for i in range(number_of_data)] metadata = [generate_json(i, metadata_key) for i in range(number_of_data)] images = ["deeplake/tests/dummy_data/images/car.jpg" for i in range(number_of_data)] diff --git a/deeplake/core/vectorstore/vector_search/vector_search.py b/deeplake/core/vectorstore/vector_search/vector_search.py index 4606fd1257..4e2b079676 100644 --- 
a/deeplake/core/vectorstore/vector_search/vector_search.py +++ b/deeplake/core/vectorstore/vector_search/vector_search.py @@ -28,9 +28,9 @@ def search( query_embedding: Optional[Union[List[float], np.ndarray]] = None, embedding_tensor: str = "embedding", return_view: bool = False, - deep_memory: bool = False, token: Optional[str] = None, org_id: Optional[str] = None, + return_tql: bool = False, ) -> Union[Dict, DeepLakeDataset]: """Searching function Args: @@ -50,9 +50,9 @@ def search( return_tensors (Optional[List[str]], optional): List of tensors to return data for. embedding_tensor (str): name of the tensor in the dataset with `htype="embedding"`. Defaults to "embedding". return_view (Bool): Return a Deep Lake dataset view that satisfied the search parameters, instead of a dictinary with data. Defaults to False. - deep_memory (bool): Use DeepMemory for the search. Defaults to False. token (Optional[str], optional): Token used for authentication. Defaults to None. org_id (Optional[str], optional): Organization ID, is needed only for local datasets. Defaults to None. + return_tql (bool): Return TQL query used for the search. Defaults to False. """ return EXEC_OPTION_TO_SEARCH_TYPE[exec_option]( query=query, @@ -66,7 +66,7 @@ def search( k=k, return_tensors=return_tensors, return_view=return_view, - deep_memory=deep_memory, token=token, org_id=org_id, + return_tql=return_tql, ) diff --git a/deeplake/requirements/tests.txt b/deeplake/requirements/tests.txt index d2f66c77a8..373fa63000 100644 --- a/deeplake/requirements/tests.txt +++ b/deeplake/requirements/tests.txt @@ -2,6 +2,7 @@ pytest pytest-cases pytest-benchmark pytest-cov +pytest-mock pytest-timeout pytest-rerunfailures pytest-profiling diff --git a/deeplake/tests/path_fixtures.py b/deeplake/tests/path_fixtures.py index 659ea30cd5..e192175c90 100644 --- a/deeplake/tests/path_fixtures.py +++ b/deeplake/tests/path_fixtures.py @@ -792,3 +792,75 @@ def precomputed_jobs_list(): with open(os.path.join(parent, "precomputed_jobs_list.txt"), "r") as f: jobs = f.read() return jobs + + +@pytest.fixture +def local_dmv2_dataset(request, hub_cloud_dev_token): + dmv2_path = f"hub://{HUB_CLOUD_DEV_USERNAME}/dmv2" + + local_cache_path = ".deepmemory_tests_cache/" + if not os.path.exists(local_cache_path): + os.mkdir(local_cache_path) + + dataset_cache_path = local_cache_path + "dmv2" + if not os.path.exists(dataset_cache_path): + deeplake.deepcopy( + dmv2_path, + dataset_cache_path, + token=hub_cloud_dev_token, + overwrite=True, + ) + + corpus = _get_storage_path(request, LOCAL) + + deeplake.deepcopy( + dataset_cache_path, + corpus, + token=hub_cloud_dev_token, + overwrite=True, + ) + yield corpus + + delete_if_exists(corpus, hub_cloud_dev_token) + + +@pytest.fixture +def deepmemory_small_dataset_copy(request, hub_cloud_dev_token): + dm_path = f"hub://{HUB_CLOUD_DEV_USERNAME}/tiny_dm_dataset" + queries_path = f"hub://{HUB_CLOUD_DEV_USERNAME}/queries_vs" + + local_cache_path = ".deepmemory_tests_cache/" + if not os.path.exists(local_cache_path): + os.mkdir(local_cache_path) + + dataset_cache_path = local_cache_path + "tiny_dm_queries" + if not os.path.exists(dataset_cache_path): + deeplake.deepcopy( + queries_path, + dataset_cache_path, + token=hub_cloud_dev_token, + overwrite=True, + ) + + corpus = _get_storage_path(request, HUB_CLOUD) + query_vs = VectorStore( + path=dataset_cache_path, + ) + queries = query_vs.dataset.text.data()["value"] + relevance = query_vs.dataset.metadata.data()["value"] + relevance = [rel["relevance"] for rel in relevance] + 
+ deeplake.deepcopy( + dm_path, + corpus, + token=hub_cloud_dev_token, + overwrite=True, + runtime={"tensor_db": True}, + ) + + queries_path = corpus + "_eval_queries" + + yield corpus, queries, relevance, queries_path + + delete_if_exists(corpus, hub_cloud_dev_token) + delete_if_exists(queries_path, hub_cloud_dev_token)
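
For reviewers, a condensed usage sketch of the two user-facing additions in this patch (the `return_tql` search flag and the Deep Memory `get_model`/`set_model` pair), distilled from the new tests above. The dataset path, token, and model job id below are placeholders modeled on the test fixtures, not production values.

import numpy as np
from deeplake import VectorStore

# Placeholder path/token. Per python/vector_search.py, return_tql=True raises
# NotImplementedError for the pure-python exec_option, so a compute_engine- or
# tensor_db-backed dataset is assumed here.
db = VectorStore(path="hub://org/my-dataset", token="<api-token>")

# search() now also returns the TQL string it generated when return_tql=True;
# the query embedding dimension must match the dataset's embedding tensor.
output = db.search(embedding=np.zeros(3, dtype=np.float32), return_tql=True)
print(output["tql"])  # e.g. "select text, metadata, id, score from (...)"

# Deep Memory model selection; the job id mirrors the ones used in
# test_deepmemory_v2_set_model_should_set_model_for_all_subsequent_loads and
# must be a key of embedding.info["deepmemory"], otherwise set_model raises ValueError.
print(db.deep_memory.get_model())                        # currently active job id
db.deep_memory.set_model("655f86e8ab93e7fc5067a3ac_1")   # persisted under "model.npy"
assert db.deep_memory.get_model() == "655f86e8ab93e7fc5067a3ac_1"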