Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion nemoguardrails/embeddings/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):

embedding_model: str
embedding_engine: str
embedding_params: Dict[str, Any]
index: AnnoyIndex
embedding_size: int
cache_config: EmbeddingsCacheConfig
Expand All @@ -60,6 +61,7 @@ def __init__(
self,
embedding_model=None,
embedding_engine=None,
embedding_params=None,
index=None,
cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
search_threshold: float = None,
Expand All @@ -83,6 +85,7 @@ def __init__(
self._embeddings = []
self.embedding_model = embedding_model
self.embedding_engine = embedding_engine
self.embedding_params = embedding_params or {}
self._embedding_size = 0
self.search_threshold = search_threshold or float("inf")
if isinstance(cache_config, Dict):
Expand Down Expand Up @@ -132,7 +135,9 @@ def embeddings_index(self, index):
def _init_model(self):
    """Create the embedding model instance backing this index."""
    # All three settings were captured in __init__; forward them as one unit.
    model_config = {
        "embedding_model": self.embedding_model,
        "embedding_engine": self.embedding_engine,
        "embedding_params": self.embedding_params,
    }
    self._model = init_embedding_model(**model_config)

@cache_embeddings
Expand Down
15 changes: 12 additions & 3 deletions nemoguardrails/embeddings/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,15 @@ def register_embedding_provider(
register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel)


def init_embedding_model(embedding_model: str, embedding_engine: str) -> EmbeddingModel:
def init_embedding_model(
    embedding_model: str, embedding_engine: str, embedding_params: dict = None
) -> EmbeddingModel:
    """Initialize the embedding model.

    Instances are cached per (engine, model, params) combination, so repeated
    calls with the same configuration return the same model object.

    Args:
        embedding_model (str): The path or name of the embedding model.
        embedding_engine (str): The name of the embedding engine.
        embedding_params (dict): Additional parameters for the embedding model.

    Returns:
        EmbeddingModel: An instance of the initialized embedding model.

    Raises:
        ValueError: If the embedding engine is invalid.
    """
    # A `None` default avoids the mutable-default-argument pitfall (a shared
    # `{}` default would be the same object across all calls).
    embedding_params = embedding_params or {}

    # Sort the items so the cache key does not depend on dict insertion order:
    # {"a": 1, "b": 2} and {"b": 2, "a": 1} must hit the same cache entry.
    embedding_params_str = (
        "_".join(f"{key}={value}" for key, value in sorted(embedding_params.items()))
        or "default"
    )

    model_key = f"{embedding_engine}-{embedding_model}-{embedding_params_str}"

    if model_key not in _embedding_model_cache:
        provider_class = EmbeddingProviderRegistry().get(embedding_engine)
        model = provider_class(embedding_model=embedding_model, **embedding_params)
        _embedding_model_cache[model_key] = model

    return _embedding_model_cache[model_key]
6 changes: 3 additions & 3 deletions nemoguardrails/embeddings/providers/fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):

engine_name = "FastEmbed"

def __init__(self, embedding_model: str):
def __init__(self, embedding_model: str, **kwargs):
from fastembed import TextEmbedding as Embedding

# Enabling a short form model name for all-MiniLM-L6-v2.
Expand All @@ -50,13 +50,13 @@ def __init__(self, embedding_model: str):
self.embedding_model = embedding_model

try:
self.model = Embedding(embedding_model)
self.model = Embedding(embedding_model, **kwargs)
except ValueError as ex:
# Sometimes the cached model in the temporary folder gets removed,
# but the folder still exists, which causes an error. In this case,
# we fall back to an explicit cache directory.
if "Could not find model.onnx in" in str(ex):
self.model = Embedding(embedding_model, cache_dir=".cache")
self.model = Embedding(embedding_model, cache_dir=".cache", **kwargs)
else:
raise ex

Expand Down
4 changes: 2 additions & 2 deletions nemoguardrails/embeddings/providers/nim.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ class NIMEmbeddingModel(EmbeddingModel):

engine_name = "nim"

def __init__(self, embedding_model: str):
def __init__(self, embedding_model: str, **kwargs):
try:
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

self.model = embedding_model
self.document_embedder = NVIDIAEmbeddings(model=embedding_model)
self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs)

except ImportError:
raise ImportError(
Expand Down
3 changes: 2 additions & 1 deletion nemoguardrails/embeddings/providers/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class OpenAIEmbeddingModel(EmbeddingModel):
def __init__(
self,
embedding_model: str,
**kwargs,
):
try:
import openai
Expand All @@ -59,7 +60,7 @@ def __init__(
)

self.model = embedding_model
self.client = OpenAI()
self.client = OpenAI(**kwargs)

self.embedding_size_dict = {
"text-embedding-ada-002": 1536,
Expand Down
4 changes: 2 additions & 2 deletions nemoguardrails/embeddings/providers/sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):

engine_name = "SentenceTransformers"

def __init__(self, embedding_model: str):
def __init__(self, embedding_model: str, **kwargs):
try:
from sentence_transformers import SentenceTransformer
except ImportError:
Expand All @@ -58,7 +58,7 @@ def __init__(self, embedding_model: str):
)

device = "cuda" if cuda.is_available() else "cpu"
self.model = SentenceTransformer(embedding_model, device=device)
self.model = SentenceTransformer(embedding_model, device=device, **kwargs)
# Get the embedding dimension of the model
self.embedding_size = self.model.get_sentence_embedding_dimension()

Expand Down
5 changes: 5 additions & 0 deletions nemoguardrails/rails/llm/llmrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
# The default embeddings model is using FastEmbed
self.default_embedding_model = "all-MiniLM-L6-v2"
self.default_embedding_engine = "FastEmbed"
self.default_embedding_params = {}

# We keep a cache of the events history associated with a sequence of user messages.
# TODO: when we update the interface to allow to return a "state object", this
Expand Down Expand Up @@ -212,6 +213,7 @@ def __init__(
if model.type == "embeddings":
self.default_embedding_model = model.model
self.default_embedding_engine = model.engine
self.default_embedding_params = model.parameters or {}
break

# InteractionLogAdapters used for tracing
Expand Down Expand Up @@ -429,6 +431,9 @@ def _get_embeddings_search_provider_instance(
embedding_engine=esp_config.parameters.get(
"embedding_engine", self.default_embedding_engine
),
embedding_params=esp_config.parameters.get(
"embedding_parameters", self.default_embedding_params
),
cache_config=esp_config.cache,
# We make sure we also pass additional relevant params.
**{
Expand Down
161 changes: 161 additions & 0 deletions tests/test_embedding_providers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import List

import pytest

from nemoguardrails.embeddings.providers import (
init_embedding_model,
register_embedding_provider,
)
from nemoguardrails.embeddings.providers.base import EmbeddingModel

SUPPORTED_PARAMS = {"param1", "param2"}


class MockEmbeddingModel(EmbeddingModel):
    """Mock embedding model used by the provider tests.

    Supported embedding models:
    - mock-embedding-small: Embedding size of 128.
    - mock-embedding-large: Embedding size of 256.
    Supported parameters:
    - param1
    - param2

    Args:
        embedding_model (str): The name of the embedding model.

    Attributes:
        model (str): The name of the embedding model.
        embedding_size (int): The size of the embeddings.

    Methods:
        encode: Encode a list of documents into embeddings.
    """

    engine_name = "mock_engine"

    def __init__(self, embedding_model: str, **kwargs):
        self.model = embedding_model
        self.embedding_size_dict = {
            "mock-embedding-small": 128,
            "mock-embedding-large": 256,
        }
        self.embedding_params = kwargs

        # Reject unknown model names before looking at any parameters.
        if self.model not in self.embedding_size_dict:
            raise ValueError(f"Invalid embedding model: {self.model}")

        # Every extra keyword argument must be one of the supported params.
        for param in self.embedding_params:
            if param not in SUPPORTED_PARAMS:
                raise ValueError(f"Unsupported parameter: {param}")

        self.embedding_size = self.embedding_size_dict[self.model]

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings asynchronously.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.
        """
        # One deterministic vector [0.0, 1.0, ...] of the model's size per doc.
        template = [float(i) for i in range(self.embedding_size)]
        return [list(template) for _ in documents]


register_embedding_provider(MockEmbeddingModel)


def test_init_embedding_model_with_params():
    """Passing a supported parameter yields a configured mock model."""
    params = {next(iter(SUPPORTED_PARAMS)): "value1"}
    model = init_embedding_model("mock-embedding-small", "mock_engine", params)
    assert isinstance(model, MockEmbeddingModel)
    assert model.model == "mock-embedding-small"
    assert model.embedding_size == 128
    assert model.engine_name == "mock_engine"
    assert model.embedding_params == params


def test_init_embedding_model_without_params():
    """Omitting params falls back to an empty parameter dict."""
    model = init_embedding_model("mock-embedding-large", "mock_engine")
    assert isinstance(model, MockEmbeddingModel)
    assert model.model == "mock-embedding-large"
    assert model.embedding_size == 256
    assert model.engine_name == "mock_engine"
    assert model.embedding_params == {}


def test_init_embedding_model_with_unsupported_params():
    """An unknown parameter name is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Unsupported parameter: unsupported_param"):
        init_embedding_model(
            "mock-embedding-small", "mock_engine", {"unsupported_param": "value"}
        )


def test_init_embedding_model_with_invalid_model():
    """An unknown model name is rejected with a ValueError."""
    with pytest.raises(ValueError, match="Invalid embedding model: invalid_model"):
        init_embedding_model("invalid_model", "mock_engine", {"param1": "value1"})


def test_encode_method():
    """encode returns one embedding of the model's size per document."""
    model = init_embedding_model("mock-embedding-small", "mock_engine")
    assert isinstance(model, MockEmbeddingModel)
    docs = ["doc1", "doc2", "doc3"]
    vectors = model.encode(docs)
    assert len(vectors) == len(docs)
    assert len(vectors[0]) == model.embedding_size


@pytest.mark.asyncio
async def test_encode_async_method():
    """encode_async returns one embedding of the model's size per document."""
    model = init_embedding_model("mock-embedding-large", "mock_engine")
    assert isinstance(model, MockEmbeddingModel)
    docs = ["doc1", "doc2", "doc3"]
    vectors = await model.encode_async(docs)
    assert len(vectors) == len(docs)
    assert len(vectors[0]) == model.embedding_size
7 changes: 7 additions & 0 deletions tests/test_embeddings_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ def test_sync_embeddings():
assert len(result[0]) == 384


def test_additional_params_with_fastembed():
    """Extra kwargs (max_length, lazy_load) are forwarded to FastEmbed."""
    embedder = FastEmbedEmbeddingModel(
        "all-MiniLM-L6-v2", max_length=512, lazy_load=True
    )
    # all-MiniLM-L6-v2 produces 384-dimensional embeddings.
    assert len(embedder.encode(["test"])[0]) == 384


@pytest.mark.asyncio
async def test_async_embeddings():
model = FastEmbedEmbeddingModel("all-MiniLM-L6-v2")
Expand Down