Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add GigaChat Encoder #332

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
167 changes: 167 additions & 0 deletions docs/encoders/gigachat.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "# Using GigaChatEncoder",
"id": "35d3b3544b0b2bf5"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Getting Started",
"id": "8a04e30ad27664cb"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We start by installing semantic-router.",
"id": "e15f40cfbd181277"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "!pip install -qU semantic-router",
"id": "a22753e184585d66"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Next, we define a `Route` object together with example utterances that should trigger that route.",
"id": "c6ab1caebff2d748"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"from semantic_router import Route\n",
"\n",
"politics = Route(\n",
" name=\"politics\",\n",
" utterances=[\n",
" \"isn't politics the best thing ever\",\n",
" \"why don't you tell me about your political opinions\",\n",
" \"don't you just love the president\",\n",
" \"don't you just hate the president\",\n",
" \"they're going to destroy this country!\",\n",
" \"they will save the country!\",\n",
" ],\n",
")"
],
"id": "1387c6a6b2399cbb"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Let's define another for good measure:",
"id": "d14c31bb9ba0a2cf"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"chitchat = Route(\n",
" name=\"chitchat\",\n",
" utterances=[\n",
" \"how's the weather today?\",\n",
" \"how are things going?\",\n",
" \"lovely weather today\",\n",
" \"the weather is horrendous\",\n",
" \"let's go to the chippy\",\n",
" ],\n",
")\n",
"\n",
"routes = [politics, chitchat]"
],
"id": "9433a9a4d8420d4a"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now we initialize our embedding model.",
"id": "ebb87de5d9181b90"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"from semantic_router.encoders import GigaChatEncoder\n",
"\n",
"auth_data = \"your-auth-data\"\n",
"scope = \"your-scope\"  # \"GIGACHAT_API_PERS\" (personal use) or \"GIGACHAT_API_CORP\" (corporate use)\n",
"\n",
"encoder = GigaChatEncoder(auth_data=auth_data, scope=scope)"
],
"id": "954563c1102f8f5d"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now we define the RouteLayer. When called, the route layer will consume text (a query) and output the category (Route) it belongs to — to initialize a RouteLayer we need our encoder model and a list of routes.",
"id": "580ba91ad0dce419"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"from semantic_router.layer import RouteLayer\n",
"\n",
"rl = RouteLayer(encoder=encoder, routes=routes)"
],
"id": "7db9e2ea9afdf0ec"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now we can test it:",
"id": "6b456a5153ec37e7"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "rl(\"don't you love politics?\")",
"id": "c552767d54a45455"
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "rl(\"how's the weather today?\")",
"id": "b5e95b8cd6b009c3"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
4 changes: 4 additions & 0 deletions semantic_router/encoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from semantic_router.encoders.tfidf import TfidfEncoder
from semantic_router.encoders.vit import VitEncoder
from semantic_router.encoders.zure import AzureOpenAIEncoder
from semantic_router.encoders.gigachat import GigaChatEncoder
from semantic_router.schema import EncoderType

__all__ = [
Expand All @@ -31,6 +32,7 @@
"CLIPEncoder",
"GoogleEncoder",
"BedrockEncoder",
"GigaChatEncoder",
]


Expand Down Expand Up @@ -71,6 +73,8 @@ def __init__(self, type: str, name: Optional[str]):
self.model = GoogleEncoder(name=name)
elif self.type == EncoderType.BEDROCK:
self.model = BedrockEncoder(name=name) # type: ignore
elif self.type == EncoderType.GIGACHAT:
self.model = GigaChatEncoder(name=name)
else:
raise ValueError(f"Encoder type '{type}' not supported")

Expand Down
72 changes: 72 additions & 0 deletions semantic_router/encoders/gigachat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os
from time import sleep
from typing import Any, List, Optional

from gigachat import GigaChat # Install the GigaChat API library via 'pip install gigachat'

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault

class GigaChatEncoder(BaseEncoder):
    """GigaChat encoder class for generating embeddings.

    Attributes:
        client (Optional[Any]): Instance of the GigaChat client.
        type (str): Type identifier for the encoder, which is "gigachat".
    """

    client: Optional[Any] = None
    type: str = "gigachat"

    def __init__(
        self,
        name: Optional[str] = None,
        auth_data: Optional[str] = None,
        scope: Optional[str] = None,
        score_threshold: float = 0.75,
    ):
        """Initializes the GigaChatEncoder.

        Args:
            name (Optional[str]): Name of the encoder model. Defaults to the
                ``embedding_model`` entry of ``EncoderDefault.GIGACHAT``.
            auth_data (Optional[str]): Authorization data for GigaChat. Falls
                back to the ``GIGACHAT_AUTH_DATA`` environment variable.
            scope (Optional[str]): Scope of the GigaChat API usage. Falls back
                to the ``GIGACHAT_SCOPE`` environment variable when omitted.
            score_threshold (float): Threshold for scoring embeddings.

        Raises:
            ValueError: If no authorization data or no scope can be resolved
                from the arguments or the environment, or if the GigaChat
                client fails to initialize.
        """
        if name is None:
            name = EncoderDefault.GIGACHAT.value["embedding_model"]
        super().__init__(name=name, score_threshold=score_threshold)

        auth_data = auth_data or os.getenv("GIGACHAT_AUTH_DATA")
        if auth_data is None:
            raise ValueError("GigaChat authorization data cannot be 'None'.")

        # Mirror the auth_data behaviour: callers that cannot pass a scope
        # (e.g. construction by name only) can still supply it via the
        # environment instead of always hitting the error below.
        if scope is None:
            scope = os.getenv("GIGACHAT_SCOPE")
        if scope is None:
            raise ValueError(
                "GigaChat scope cannot be 'None'. Set 'GIGACHAT_API_PERS' for personal use or 'GIGACHAT_API_CORP' for corporate use."
            )

        try:
            # NOTE(security): TLS certificate verification is disabled here.
            # Left unchanged to preserve behaviour; review whether this should
            # be configurable or enabled by default.
            self.client = GigaChat(
                scope=scope, credentials=auth_data, verify_ssl_certs=False
            )
        except Exception as e:
            raise ValueError(
                f"GigaChat client failed to initialize. Error: {e}"
            ) from e

    def __call__(self, docs: List[str]) -> List[List[float]]:
        """Generates embeddings for a list of documents.

        Args:
            docs: List of documents to generate embeddings for.

        Returns:
            List: List of embeddings for each document.

        Raises:
            ValueError: If the client is not initialized or the GigaChat call fails.
        """
        if self.client is None:
            raise ValueError("GigaChat client is not initialized.")
        try:
            # NOTE(review): self.name is not forwarded to the client, so the
            # configured model name has no effect on this call - confirm
            # whether the client's embeddings() should receive it.
            response = self.client.embeddings(docs)
            return [item.embedding for item in response.data]
        except Exception as e:
            raise ValueError(f"GigaChat call failed. Error: {e}") from e
2 changes: 2 additions & 0 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class EncoderType(Enum):
CLIP = "clip"
GOOGLE = "google"
BEDROCK = "bedrock"
GIGACHAT = "gigachat"



class EncoderInfo(BaseModel):
Expand Down
5 changes: 5 additions & 0 deletions semantic_router/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ class EncoderDefault(Enum):
"BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-image-v1"
)
}
GIGACHAT = {
"embedding_model": os.getenv(
"GIGACHAT_EMBEDDING_MODEL", "Embeddings"
)
}
57 changes: 57 additions & 0 deletions tests/unit/encoders/test_gigachat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import pytest
from unittest.mock import patch, Mock, MagicMock
from semantic_router.encoders import GigaChatEncoder

@pytest.fixture
def gigachat_encoder(mocker):
    """Return a GigaChatEncoder whose underlying GigaChat client is mocked.

    The encoder module binds the client class with
    ``from gigachat import GigaChat``, so the patch must target the name in
    the encoder module's namespace; patching ``gigachat.GigaChat`` would leave
    the already-imported reference untouched.
    """
    mocker.patch("semantic_router.encoders.gigachat.GigaChat")
    return GigaChatEncoder(auth_data="test_auth_data", scope="GIGACHAT_API_PERS")

class TestGigaChatEncoder:
    """Unit tests for GigaChatEncoder construction and embedding calls."""

    def test_gigachat_encoder_init_success(self, mocker):
        # Mock the client class where the encoder module looks it up so that
        # no real GigaChat client is constructed during the test.
        mocker.patch("semantic_router.encoders.gigachat.GigaChat")
        encoder = GigaChatEncoder(auth_data="test_auth_data", scope="GIGACHAT_API_PERS")
        assert encoder.client is not None
        assert encoder.type == "gigachat"

    def test_gigachat_encoder_init_no_auth_data(self, monkeypatch):
        # Remove only the variable under test instead of replacing os.getenv
        # for every caller in the process.
        monkeypatch.delenv("GIGACHAT_AUTH_DATA", raising=False)
        with pytest.raises(ValueError) as e:
            GigaChatEncoder(scope="GIGACHAT_API_PERS")
        # Assertions on the exception live outside the `with` block; code
        # placed after the raising call inside the block would never run.
        assert str(e.value) == "GigaChat authorization data cannot be 'None'."

    def test_gigachat_encoder_init_no_scope(self, monkeypatch):
        # Make sure an ambient scope setting cannot mask the missing argument.
        monkeypatch.delenv("GIGACHAT_SCOPE", raising=False)
        with pytest.raises(ValueError) as e:
            GigaChatEncoder(auth_data="test_auth_data")
        assert (
            str(e.value)
            == "GigaChat scope cannot be 'None'. Set 'GIGACHAT_API_PERS' for personal use or 'GIGACHAT_API_CORP' for corporate use."
        )

    def test_gigachat_encoder_call_uninitialized_client(self, gigachat_encoder):
        gigachat_encoder.client = None
        with pytest.raises(ValueError) as e:
            gigachat_encoder(["test document"])
        assert "GigaChat client is not initialized." in str(e.value)

    def test_gigachat_encoder_call_success(self, gigachat_encoder, mocker):
        mock_embeddings = Mock()
        mock_embeddings.data = [Mock(embedding=[0.1, 0.2])]

        mocker.patch.object(
            gigachat_encoder.client, "embeddings", return_value=mock_embeddings
        )

        embeddings = gigachat_encoder(["test document"])

        assert embeddings == [[0.1, 0.2]]
        gigachat_encoder.client.embeddings.assert_called_with(["test document"])

    def test_call_method_api_failure(self, gigachat_encoder):
        gigachat_encoder.client.embeddings = MagicMock(
            side_effect=Exception("API failure")
        )
        docs = ["document1", "document2"]
        with pytest.raises(ValueError, match="GigaChat call failed. Error: API failure"):
            gigachat_encoder(docs)

    def test_init_failure_no_env_vars(self, monkeypatch):
        # Without this the test result depends on whether the developer's
        # shell happens to export GIGACHAT_AUTH_DATA.
        monkeypatch.delenv("GIGACHAT_AUTH_DATA", raising=False)
        with pytest.raises(ValueError) as excinfo:
            GigaChatEncoder()
        assert "GigaChat authorization data cannot be 'None'" in str(excinfo.value)