Skip to content

Commit

Permalink
community[patch]: Voyage AI updates default model and batch size (lan…
Browse files Browse the repository at this point in the history
…gchain-ai#17655)

- **Description:** update the default model and batch size in
VoyageEmbeddings
    - **Issue:** N/A
    - **Dependencies:** N/A
    - **Twitter handle:** N/A

---------

Co-authored-by: fodizoltan <zoltan@conway.expert>
  • Loading branch information
2 people authored and gkorland committed Mar 30, 2024
1 parent 574024b commit b3b4cb2
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 15 deletions.
6 changes: 4 additions & 2 deletions docs/docs/integrations/text_embedding/voyageai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"id": "137cfde9-b88c-409a-9394-a9e31a6bf30d",
"metadata": {},
"source": [
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key."
"Voyage AI utilizes API keys to monitor usage and manage permissions. To obtain your key, create an account on our [homepage](https://www.voyageai.com). Then, create a VoyageEmbeddings model with your API key. Please refer to the documentation for further details on the available models: https://docs.voyageai.com/embeddings/"
]
},
{
Expand All @@ -37,7 +37,9 @@
"metadata": {},
"outputs": [],
"source": [
"embeddings = VoyageEmbeddings(voyage_api_key=\"[ Your Voyage API key ]\")"
"embeddings = VoyageEmbeddings(\n",
" voyage_api_key=\"[ Your Voyage API key ]\", model=\"voyage-2\"\n",
")"
]
},
{
Expand Down
42 changes: 30 additions & 12 deletions libs/community/langchain_community/embeddings/voyageai.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ class VoyageEmbeddings(BaseModel, Embeddings):
from langchain_community.embeddings import VoyageEmbeddings
voyage = VoyageEmbeddings(voyage_api_key="your-api-key")
voyage = VoyageEmbeddings(voyage_api_key="your-api-key", model="voyage-2")
text = "This is a test query."
query_result = voyage.embed_query(text)
"""

model: str = "voyage-01"
model: str
voyage_api_base: str = "https://api.voyageai.com/v1/embeddings"
voyage_api_key: Optional[SecretStr] = None
batch_size: int = 8
batch_size: int
"""Maximum number of texts to embed in each API request."""
max_retries: int = 6
"""Maximum number of retries to make when generating."""
Expand All @@ -86,15 +86,12 @@ class VoyageEmbeddings(BaseModel, Embeddings):
show_progress_bar: bool = False
"""Whether to show a progress bar when embedding. Must have tqdm installed if set
to True."""
truncation: Optional[bool] = None
truncation: bool = True
"""Whether to truncate the input texts to fit within the context length.
If True, over-length input texts will be truncated to fit within the context
length, before vectorized by the embedding model. If False, an error will be
raised if any given text exceeds the context length. If not specified
(defaults to None), we will truncate the input text before sending it to the
embedding model if it slightly exceeds the context window length. If it
significantly exceeds the context window length, an error will be raised."""
raised if any given text exceeds the context length."""

class Config:
"""Configuration for this pydantic object."""
Expand All @@ -107,6 +104,22 @@ def validate_environment(cls, values: Dict) -> Dict:
values["voyage_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "voyage_api_key", "VOYAGE_API_KEY")
)

if "model" not in values:
values["model"] = "voyage-01"
logger.warning(
"model will become a required arg for VoyageAIEmbeddings, "
"we recommend to specify it when using this class. "
"Currently the default is set to voyage-01."
)

if "batch_size" not in values:
values["batch_size"] = (
72
if "model" in values and (values["model"] in ["voyage-2", "voyage-02"])
else 7
)

return values

def _invocation_params(
Expand All @@ -116,11 +129,14 @@ def _invocation_params(
params: Dict = {
"url": self.voyage_api_base,
"headers": {"Authorization": f"Bearer {api_key}"},
"json": {"model": self.model, "input": input, "input_type": input_type},
"json": {
"model": self.model,
"input": input,
"input_type": input_type,
"truncation": self.truncation,
},
"timeout": self.request_timeout,
}
if self.truncation is not None:
params["json"]["truncation"] = self.truncation
return params

def _get_embeddings(
Expand Down Expand Up @@ -186,7 +202,9 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
return self._get_embeddings([text], input_type="query")[0]
return self._get_embeddings(
[text], batch_size=self.batch_size, input_type="query"
)[0]

def embed_general_texts(
self, texts: List[str], *, input_type: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from langchain_community.embeddings.voyageai import VoyageEmbeddings

# Please set VOYAGE_API_KEY in the environment variables
MODEL = "voyage-01"
MODEL = "voyage-2"


def test_voyagi_embedding_documents() -> None:
Expand All @@ -14,10 +14,22 @@ def test_voyagi_embedding_documents() -> None:
assert len(output[0]) == 1024


def test_voyagi_with_default_model() -> None:
"""Test voyage embeddings."""
embedding = VoyageEmbeddings()
assert embedding.model == "voyage-01"
assert embedding.batch_size == 7
documents = [f"foo bar {i}" for i in range(72)]
output = embedding.embed_documents(documents)
assert len(output) == 72
assert len(output[0]) == 1024


def test_voyage_embedding_documents_multiple() -> None:
"""Test voyage embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = VoyageEmbeddings(model=MODEL, batch_size=2)
assert embedding.model == MODEL
output = embedding.embed_documents(documents)
assert len(output) == 3
assert len(output[0]) == 1024
Expand Down

0 comments on commit b3b4cb2

Please sign in to comment.