1,079 changes: 665 additions & 414 deletions pdm.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -12,7 +12,7 @@ dependencies = [
"python-magic~=0.4.27",
"python-dotenv==1.0.0",
# LLM Triad
"unstract-adapters~=0.15.1",
"unstract-adapters~=0.16.0",
"llama-index==0.10.28",
"tiktoken~=0.4.0",
"transformers==4.37.0",
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.26.1"
__version__ = "0.27.0"


def get_sdk_version():
8 changes: 4 additions & 4 deletions src/unstract/sdk/audit.py
@@ -1,3 +1,5 @@
from typing import Any

import requests
from llama_index.core.callbacks import CBEventType, TokenCountingHandler

@@ -26,7 +28,7 @@ def push_usage_data(
token_counter: TokenCountingHandler = None,
model_name: str = "",
event_type: CBEventType = None,
**kwargs,
kwargs: dict[Any, Any] = None,
) -> None:
"""Pushes the usage data to the platform service.

Expand Down Expand Up @@ -84,9 +86,7 @@ def push_usage_data(
headers = {"Authorization": f"Bearer {bearer_token}"}

try:
response = requests.post(
url, headers=headers, json=data, timeout=30
)
response = requests.post(url, headers=headers, json=data, timeout=30)
if response.status_code != 200:
self.stream_log(
log=(
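With this change, callers pass the usage metadata as an explicit `kwargs` dict instead of spreading keyword arguments. A hypothetical call site might look as follows; only the parameters visible in this hunk are shown, and `audit`, the IDs, and the model name are placeholders:

```python
from llama_index.core.callbacks import CBEventType, TokenCountingHandler

token_counter = TokenCountingHandler()
usage_kwargs = {"run_id": "run-123", "execution_id": "exec-456"}  # placeholder IDs

# `audit` is assumed to be an instance of the SDK class that defines push_usage_data().
audit.push_usage_data(
    token_counter=token_counter,
    model_name="gpt-4",  # placeholder model name
    event_type=CBEventType.LLM,
    kwargs=usage_kwargs,  # previously spread as **kwargs
)
```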
5 changes: 5 additions & 0 deletions src/unstract/sdk/constants.py
@@ -146,3 +146,8 @@ class ToolSettingsKey:
EMBEDDING_ADAPTER_ID = "embeddingAdapterId"
VECTOR_DB_ADAPTER_ID = "vectorDbAdapterId"
X2TEXT_ADAPTER_ID = "x2TextAdapterId"
ADAPTER_INSTANCE_ID = "adapter_instance_id"
EMBEDDING_DIMENSION = "embedding_dimension"
RUN_ID = "run_id"
WORKFLOW_ID = "workflow_id"
EXECUTION_ID = "execution_id"
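These keys line up with the usage metadata the reworked SDK classes carry around (embedding.py below, for instance, stamps `adapter_instance_id` into its usage kwargs). A small sketch of a metadata dict built with the new values; the IDs are placeholders and it is an assumption that callers assemble the dict this way:

```python
# Placeholder IDs keyed by the newly added constant values.
usage_kwargs = {
    "run_id": "7e9c1d2a",         # RUN_ID
    "workflow_id": "wf-0001",     # WORKFLOW_ID
    "execution_id": "exec-0042",  # EXECUTION_ID
}
```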
63 changes: 49 additions & 14 deletions src/unstract/sdk/embedding.py
@@ -1,22 +1,44 @@
from typing import Any

from llama_index.core.base.embeddings.base import Embedding
from llama_index.core.embeddings import BaseEmbedding
from typing_extensions import deprecated
from unstract.adapters.constants import Common
from unstract.adapters.embedding import adapters

from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel
from unstract.sdk.exceptions import SdkError, ToolEmbeddingError
from unstract.sdk.constants import LogLevel, ToolEnv
from unstract.sdk.exceptions import EmbeddingError, SdkError
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils.callback_manager import CallbackManager


class Embedding:
_TEST_SNIPPET = "Hello, I am Unstract"
MAX_TOKENS = 1024 * 16
embedding_adapters = adapters

class ToolEmbedding:
__TEST_SNIPPET = "Hello, I am Unstract"
def __init__(
self,
tool: BaseTool,
adapter_instance_id: str,
usage_kwargs: dict[Any, Any] = None,
):
self._tool = tool
self._adapter_instance_id = adapter_instance_id
self._embedding_instance: BaseEmbedding = self._get_embedding()
self._length: int = self._get_embedding_length()

def __init__(self, tool: BaseTool):
self.tool = tool
self.max_tokens = 1024 * 16
self.embedding_adapters = adapters
self._usage_kwargs = usage_kwargs.copy()
self._usage_kwargs["adapter_instance_id"] = adapter_instance_id
platform_api_key = self._tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY)
CallbackManager.set_callback_manager(
platform_api_key=platform_api_key,
model=self._embedding_instance,
kwargs=self._usage_kwargs,
)

def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
def _get_embedding(self) -> BaseEmbedding:
"""Gets an instance of LlamaIndex's embedding object.

Args:
@@ -27,7 +49,7 @@ def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
"""
try:
embedding_config_data = ToolAdapter.get_adapter_config(
self.tool, adapter_instance_id
self._tool, self._adapter_instance_id
)
embedding_adapter_id = embedding_config_data.get(Common.ADAPTER_ID)
if embedding_adapter_id not in self.embedding_adapters:
@@ -42,12 +64,25 @@ def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
embedding_adapter_class = embedding_adapter(embedding_metadata)
return embedding_adapter_class.get_embedding_instance()
except Exception as e:
self.tool.stream_log(
self._tool.stream_log(
log=f"Error getting embedding: {e}", level=LogLevel.ERROR
)
raise ToolEmbeddingError(f"Error getting embedding instance: {e}") from e
raise EmbeddingError(f"Error getting embedding instance: {e}") from e

def get_embedding_length(self, embedding: BaseEmbedding) -> int:
embedding_list = embedding._get_text_embedding(self.__TEST_SNIPPET)
def get_query_embedding(self, query: str) -> Embedding:
return self._embedding_instance.get_query_embedding(query)

def _get_embedding_length(self) -> int:
embedding_list = self._embedding_instance._get_text_embedding(
self._TEST_SNIPPET
)
embedding_dimension = len(embedding_list)
return embedding_dimension

@deprecated("Use the new class Embedding")
def get_embedding_length(self, embedding: BaseEmbedding) -> int:
return self._get_embedding_length(embedding)


# Legacy
ToolEmbedding = Embedding
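The reworked class now resolves the adapter, probes the embedding dimension, and registers the usage callback at construction time, so there is no separate `get_embedding()` / `get_embedding_length()` call sequence; the old length accessor survives only as a deprecated shim and `ToolEmbedding` as an alias. A minimal usage sketch, assuming a `BaseTool` subclass instance and a configured embedding adapter (the IDs are placeholders):

```python
from unstract.sdk.embedding import Embedding  # ToolEmbedding still works as a legacy alias
from unstract.sdk.tool.base import BaseTool


def embed_query(tool: BaseTool, adapter_instance_id: str, run_id: str) -> list[float]:
    # The constructor fetches the adapter config, builds the LlamaIndex embedding
    # instance and wires up the usage callback; usage_kwargs is copied and
    # enriched with the adapter instance ID internally.
    embedding = Embedding(
        tool=tool,
        adapter_instance_id=adapter_instance_id,
        usage_kwargs={"run_id": run_id},
    )
    # Thin wrapper over the underlying LlamaIndex get_query_embedding().
    return embedding.get_query_embedding("Hello, I am Unstract")
```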
6 changes: 3 additions & 3 deletions src/unstract/sdk/exceptions.py
@@ -17,15 +17,15 @@ def __init__(self, message: str = ""):
super().__init__(message)


class ToolLLMError(SdkError):
class LLMError(SdkError):
DEFAULT_MESSAGE = "Error ocurred related to LLM"


class ToolEmbeddingError(SdkError):
class EmbeddingError(SdkError):
DEFAULT_MESSAGE = "Error ocurred related to embedding"


class ToolVectorDBError(SdkError):
class VectorDBError(SdkError):
DEFAULT_MESSAGE = "Error ocurred related to vector DB"


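For tools upgrading from 0.26.x only the exception names change; the `SdkError` base class is untouched, so broad handlers keep working. A hedged migration sketch (the `tool` instance and adapter ID are placeholders):

```python
from unstract.sdk.constants import LogLevel
from unstract.sdk.embedding import Embedding
from unstract.sdk.exceptions import EmbeddingError  # formerly ToolEmbeddingError

try:
    embedding = Embedding(
        tool=tool,  # assumed to be a BaseTool subclass instance
        adapter_instance_id="embedding-adapter-id",  # placeholder
        usage_kwargs={},
    )
except EmbeddingError as e:
    # Raised when the adapter cannot be resolved; still a subclass of SdkError,
    # as are the renamed LLMError and VectorDBError.
    tool.stream_log(log=f"Embedding setup failed: {e}", level=LogLevel.ERROR)
```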