Skip to content

Commit

Permalink
feat: implement max marginal relevance for momento vector index (lang…
Browse files Browse the repository at this point in the history
…chain-ai#13619)

**Description**

Implements `max_marginal_relevance_search` and
`max_marginal_relevance_search_by_vector` for the Momento Vector Index
vectorstore.

Additionally bumps the `momento` dependency in the lock file and adds
logging to the implementation.

**Dependencies**

✅ updates `momento` dependency in lock file

**Tag maintainer**

@baskaryan 

**Twitter handle**

Please tag @momentohq for Momento Vector Index and @mloml for the
contribution 🙇

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
malandis authored and aymeric-roucher committed Dec 11, 2023
1 parent c7193da commit 170e29e
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 13 deletions.
94 changes: 90 additions & 4 deletions libs/langchain/langchain/vectorstores/momento_vector_index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -11,15 +12,17 @@
)
from uuid import uuid4

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore

from langchain.utils import get_from_env
from langchain.vectorstores.utils import DistanceStrategy
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance

VST = TypeVar("VST", bound="VectorStore")

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from momento import PreviewVectorIndexClient
Expand Down Expand Up @@ -75,9 +78,8 @@ def __init__(
index_name (str, optional): The name of the index to store the documents in.
Defaults to "default".
distance_strategy (DistanceStrategy, optional): The distance strategy to
use. Defaults to DistanceStrategy.COSINE. If you select
DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses the squared
Euclidean distance.
use. If you select DistanceStrategy.EUCLIDEAN_DISTANCE, Momento uses
the squared Euclidean distance. Defaults to DistanceStrategy.COSINE.
text_field (str, optional): The name of the metadata field to store the
original text in. Defaults to "text".
ensure_index_exists (bool, optional): Whether to ensure that the index
Expand Down Expand Up @@ -125,6 +127,7 @@ def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
similarity_metric = SimilarityMetric.EUCLIDEAN_SIMILARITY
else:
logger.error(f"Distance strategy {self.distance_strategy} not implemented.")
raise ValueError(
f"Distance strategy {self.distance_strategy} not implemented."
)
Expand All @@ -137,8 +140,10 @@ def _create_index_if_not_exists(self, num_dimensions: int) -> bool:
elif isinstance(response, CreateIndex.IndexAlreadyExists):
return False
elif isinstance(response, CreateIndex.Error):
logger.error(f"Error creating index: {response.inner_exception}")
raise response.inner_exception
else:
logger.error(f"Unexpected response: {response}")
raise Exception(f"Unexpected response: {response}")

def add_texts(
Expand Down Expand Up @@ -331,6 +336,87 @@ def similarity_search_by_vector(
)
return [doc for doc, _ in results]

def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
from momento.requests.vector_index import ALL_METADATA
from momento.responses.vector_index import SearchAndFetchVectors

response = self._client.search_and_fetch_vectors(
self.index_name, embedding, top_k=fetch_k, metadata_fields=ALL_METADATA
)

if isinstance(response, SearchAndFetchVectors.Success):
pass
elif isinstance(response, SearchAndFetchVectors.Error):
logger.error(f"Error searching and fetching vectors: {response}")
return []
else:
logger.error(f"Unexpected response: {response}")
raise Exception(f"Unexpected response: {response}")

mmr_selected = maximal_marginal_relevance(
query_embedding=np.array([embedding], dtype=np.float32),
embedding_list=[hit.vector for hit in response.hits],
lambda_mult=lambda_mult,
k=k,
)
selected = [response.hits[i].metadata for i in mmr_selected]
return [
Document(page_content=metadata.pop(self.text_field, ""), metadata=metadata) # type: ignore # noqa: E501
for metadata in selected
]

def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
) -> List[Document]:
"""Return docs selected using the maximal marginal relevance.
Maximal marginal relevance optimizes for similarity to query AND diversity
among selected documents.
Args:
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5.
Returns:
List of Documents selected by maximal marginal relevance.
"""
embedding = self._embedding.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding, k, fetch_k, lambda_mult, **kwargs
)

@classmethod
def from_texts(
cls: Type[VST],
Expand Down
15 changes: 7 additions & 8 deletions libs/langchain/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_from_texts_with_metadatas(


def test_from_texts_with_scores(vector_store: MomentoVectorIndex) -> None:
# """Test end to end construction and search with scores and IDs."""
"""Test end to end construction and search with scores and IDs."""
texts = ["apple", "orange", "hammer"]
metadatas = [{"page": f"{i}"} for i in range(len(texts))]

Expand Down Expand Up @@ -162,3 +162,25 @@ def test_add_documents_with_ids(vector_store: MomentoVectorIndex) -> None:
)
assert isinstance(response, Search.Success)
assert [hit.id for hit in response.hits] == ids


def test_max_marginal_relevance_search(vector_store: MomentoVectorIndex) -> None:
"""Test max marginal relevance search."""
pepperoni_pizza = "pepperoni pizza"
cheese_pizza = "cheese pizza"
hot_dog = "hot dog"

vector_store.add_texts([pepperoni_pizza, cheese_pizza, hot_dog])
wait()
search_results = vector_store.similarity_search("pizza", k=2)

assert search_results == [
Document(page_content=pepperoni_pizza, metadata={}),
Document(page_content=cheese_pizza, metadata={}),
]

search_results = vector_store.max_marginal_relevance_search(query="pizza", k=2)
assert search_results == [
Document(page_content=pepperoni_pizza, metadata={}),
Document(page_content=hot_dog, metadata={}),
]

0 comments on commit 170e29e

Please sign in to comment.