diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7c8f98e9..b5ea5f5c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -50,7 +50,7 @@ repos:
     rev: 24.2.0
     hooks:
       - id: black
-        args: [--config=pyproject.toml, -l 80]
+        args: [--config=pyproject.toml, -l 88]
         language: system
         exclude: |
           (?x)^(
@@ -60,7 +60,7 @@ repos:
     rev: 7.0.0
    hooks:
       - id: flake8
-        args: [--max-line-length=80]
+        args: [--max-line-length=88]
         exclude: |
           (?x)^(
             .*migrations/.*\.py|
diff --git a/pyproject.toml b/pyproject.toml
index db114345..3f027330 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,7 +53,7 @@ lint = [
 ]
 
 [tool.isort]
-line_length = 80
+line_length = 88
 multi_line_output = 3
 include_trailing_comma = true
 force_grid_wrap = 0
diff --git a/src/unstract/sdk/__init__.py b/src/unstract/sdk/__init__.py
index dff86478..5c5e7053 100644
--- a/src/unstract/sdk/__init__.py
+++ b/src/unstract/sdk/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.23.0"
+__version__ = "0.24.0"
 
 
 def get_sdk_version():
diff --git a/src/unstract/sdk/constants.py b/src/unstract/sdk/constants.py
index 9a1ee282..451fab02 100644
--- a/src/unstract/sdk/constants.py
+++ b/src/unstract/sdk/constants.py
@@ -146,7 +146,3 @@ class ToolSettingsKey:
     EMBEDDING_ADAPTER_ID = "embeddingAdapterId"
     VECTOR_DB_ADAPTER_ID = "vectorDbAdapterId"
     X2TEXT_ADAPTER_ID = "x2TextAdapterId"
-
-
-class FileReaderSettings:
-    FILE_READER_CHUNK_SIZE = 8192
diff --git a/src/unstract/sdk/embedding.py b/src/unstract/sdk/embedding.py
index 05496838..79238241 100644
--- a/src/unstract/sdk/embedding.py
+++ b/src/unstract/sdk/embedding.py
@@ -1,11 +1,9 @@
-from typing import Optional
-
 from llama_index.core.embeddings import BaseEmbedding
 from unstract.adapters.constants import Common
 from unstract.adapters.embedding import adapters
 
 from unstract.sdk.adapters import ToolAdapter
-from unstract.sdk.constants import LogLevel, ToolSettingsKey
+from unstract.sdk.constants import LogLevel
 from unstract.sdk.exceptions import SdkError
 from unstract.sdk.tool.base import BaseTool
 
@@ -13,48 +11,39 @@
 class ToolEmbedding:
     __TEST_SNIPPET = "Hello, I am Unstract"
 
-    def __init__(self, tool: BaseTool, tool_settings: dict[str, str] = {}):
+    def __init__(self, tool: BaseTool):
         self.tool = tool
         self.max_tokens = 1024 * 16
         self.embedding_adapters = adapters
-        self.embedding_adapter_instance_id = tool_settings.get(
-            ToolSettingsKey.EMBEDDING_ADAPTER_ID
-        )
-        self.embedding_adapter_id: Optional[str] = None
 
-    def get_embedding(
-        self, adapter_instance_id: Optional[str] = None
-    ) -> BaseEmbedding:
-        adapter_instance_id = (
-            adapter_instance_id
-            if adapter_instance_id
-            else self.embedding_adapter_instance_id
-        )
-        if not adapter_instance_id:
-            raise SdkError(
-                f"Adapter_instance_id does not have "
-                f"a valid value: {adapter_instance_id}"
-            )
+    def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
+        """Gets an instance of LlamaIndex's embedding object.
+
+        Args:
+            adapter_instance_id (str): UUID of the embedding adapter
+
+        Returns:
+            BaseEmbedding: Embedding instance
+        """
         try:
             embedding_config_data = ToolAdapter.get_adapter_config(
                 self.tool, adapter_instance_id
             )
             embedding_adapter_id = embedding_config_data.get(Common.ADAPTER_ID)
-            self.embedding_adapter_id = embedding_adapter_id
-            if embedding_adapter_id in self.embedding_adapters:
-                embedding_adapter = self.embedding_adapters[
-                    embedding_adapter_id
-                ][Common.METADATA][Common.ADAPTER]
-                embedding_metadata = embedding_config_data.get(
-                    Common.ADAPTER_METADATA
-                )
-                embedding_adapter_class = embedding_adapter(embedding_metadata)
-                return embedding_adapter_class.get_embedding_instance()
-            else:
+            if embedding_adapter_id not in self.embedding_adapters:
                 raise SdkError(
                     f"Embedding adapter not supported : "
                     f"{embedding_adapter_id}"
                 )
+
+            embedding_adapter = self.embedding_adapters[embedding_adapter_id][
+                Common.METADATA
+            ][Common.ADAPTER]
+            embedding_metadata = embedding_config_data.get(
+                Common.ADAPTER_METADATA
+            )
+            embedding_adapter_class = embedding_adapter(embedding_metadata)
+            return embedding_adapter_class.get_embedding_instance()
         except Exception as e:
             self.tool.stream_log(
                 log=f"Error getting embedding: {e}", level=LogLevel.ERROR
diff --git a/src/unstract/sdk/index.py b/src/unstract/sdk/index.py
index c615b56b..59aa38bc 100644
--- a/src/unstract/sdk/index.py
+++ b/src/unstract/sdk/index.py
@@ -1,3 +1,4 @@
+import json
 from typing import Optional
 
 from llama_index.core import Document
@@ -14,14 +15,13 @@
 from unstract.adapters.exceptions import AdapterError
 from unstract.adapters.x2text.x2text_adapter import X2TextAdapter
 
+from unstract.sdk.adapters import ToolAdapter
 from unstract.sdk.constants import LogLevel, ToolEnv
 from unstract.sdk.embedding import ToolEmbedding
 from unstract.sdk.exceptions import IndexingError, SdkError
 from unstract.sdk.tool.base import BaseTool
 from unstract.sdk.utils import ToolUtils
-from unstract.sdk.utils.callback_manager import (
-    CallbackManager as UNCallbackManager,
-)
+from unstract.sdk.utils.callback_manager import CallbackManager as UNCallbackManager
 from unstract.sdk.vector_db import ToolVectorDB
 from unstract.sdk.x2txt import X2Text
 
@@ -31,18 +31,9 @@ def __init__(self, tool: BaseTool):
         # TODO: Inherit from StreamMixin and avoid using BaseTool
         self.tool = tool
 
-    def get_text_from_index(
-        self, embedding_type: str, vector_db: str, doc_id: str
-    ):
+    def get_text_from_index(self, embedding_type: str, vector_db: str, doc_id: str):
         embedd_helper = ToolEmbedding(tool=self.tool)
-        embedding_li = embedd_helper.get_embedding(
-            adapter_instance_id=embedding_type
-        )
-        if embedding_li is None:
-            self.tool.stream_log(
-                f"Error loading {embedding_type}", level=LogLevel.ERROR
-            )
-            raise SdkError(f"Error loading {embedding_type}")
+        embedding_li = embedd_helper.get_embedding(adapter_instance_id=embedding_type)
         embedding_dimension = embedd_helper.get_embedding_length(embedding_li)
 
         vdb_helper = ToolVectorDB(
@@ -53,12 +44,6 @@ def get_text_from_index(
             embedding_dimension=embedding_dimension,
         )
 
-        if vector_db_li is None:
-            self.tool.stream_log(
-                f"Error loading {vector_db}", level=LogLevel.ERROR
-            )
-            raise SdkError(f"Error loading {vector_db}")
-
         try:
             self.tool.stream_log(f">>> Querying {vector_db}...")
             self.tool.stream_log(f">>> {doc_id}")
@@ -149,48 +134,33 @@ def index_file(
         Returns:
             str: A unique ID for the file and indexing arguments combination
         """
-        # Make file content hash if not available
-        if not file_hash:
-            file_hash = ToolUtils.get_hash_from_file(file_path=file_path)
-
-        doc_id = ToolIndex.generate_file_id(
+        doc_id = self.generate_file_id(
             tool_id=tool_id,
-            file_hash=file_hash,
             vector_db=vector_db,
             embedding=embedding_type,
             x2text=x2text_adapter,
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
+            chunk_size=str(chunk_size),
+            chunk_overlap=str(chunk_overlap),
+            file_path=file_path,
+            file_hash=file_hash,
         )
-
         self.tool.stream_log(f"Checking if doc_id {doc_id} exists")
 
-        vdb_helper = ToolVectorDB(
-            tool=self.tool,
-        )
-
+        # Get embedding instance
         embedd_helper = ToolEmbedding(tool=self.tool)
+        embedding_li = embedd_helper.get_embedding(adapter_instance_id=embedding_type)
+        embedding_dimension = embedd_helper.get_embedding_length(embedding_li)
 
-        embedding_li = embedd_helper.get_embedding(
-            adapter_instance_id=embedding_type
+        # Get vectorDB instance
+        vdb_helper = ToolVectorDB(
+            tool=self.tool,
         )
-        if embedding_li is None:
-            self.tool.stream_log(
-                f"Error loading {embedding_type}", level=LogLevel.ERROR
-            )
-            raise SdkError(f"Error loading {embedding_type}")
-
-        embedding_dimension = embedd_helper.get_embedding_length(embedding_li)
         vector_db_li = vdb_helper.get_vector_db(
             adapter_instance_id=vector_db,
             embedding_dimension=embedding_dimension,
         )
-        if vector_db_li is None:
-            self.tool.stream_log(
-                f"Error loading {vector_db}", level=LogLevel.ERROR
-            )
-            raise SdkError(f"Error loading {vector_db}")
 
+        # Checking if document is already indexed against doc_id
         doc_id_eq_filter = MetadataFilter.from_dict(
             {"key": "doc_id", "operator": FilterOperator.EQ, "value": doc_id}
         )
@@ -275,26 +245,20 @@ def index_file(
             parser = SimpleNodeParser.from_defaults(
                 chunk_size=len(documents[0].text) + 10, chunk_overlap=0
             )
-            nodes = parser.get_nodes_from_documents(
-                documents, show_progress=True
-            )
+            nodes = parser.get_nodes_from_documents(documents, show_progress=True)
             node = nodes[0]
             node.embedding = embedding_li.get_query_embedding(" ")
             vector_db_li.add(nodes=[node])
             self.tool.stream_log("Added node to vector db")
         else:
-            storage_context = StorageContext.from_defaults(
-                vector_store=vector_db_li
-            )
+            storage_context = StorageContext.from_defaults(vector_store=vector_db_li)
             parser = SimpleNodeParser.from_defaults(
                 chunk_size=chunk_size, chunk_overlap=chunk_overlap
             )
 
             # Set callback_manager to collect Usage stats
             callback_manager = UNCallbackManager.set_callback_manager(
-                platform_api_key=self.tool.get_env_or_die(
-                    ToolEnv.PLATFORM_API_KEY
-                ),
+                platform_api_key=self.tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY),
                 embedding=embedding_li,
             )
 
@@ -319,31 +283,53 @@ def index_file(
         self.tool.stream_log("File has been indexed successfully")
         return doc_id
 
-    @staticmethod
     def generate_file_id(
+        self,
         tool_id: str,
-        file_hash: str,
         vector_db: str,
         embedding: str,
         x2text: str,
         chunk_size: str,
         chunk_overlap: str,
+        file_path: Optional[str] = None,
+        file_hash: Optional[str] = None,
    ) -> str:
         """Generates a unique ID useful for identifying files during indexing.
 
         Args:
-            tool_id (str): Unique ID of the tool developed / exported
-            file_hash (str): Hash of the file contents
+            tool_id (str): Unique ID of the tool or workflow
             vector_db (str): UUID of the vector DB adapter
             embedding (str): UUID of the embedding adapter
             x2text (str): UUID of the X2Text adapter
             chunk_size (str): Chunk size for indexing
             chunk_overlap (str): Chunk overlap for indexing
+            file_path (Optional[str]): Path to the file that needs to be indexed.
+                Defaults to None. One of file_path or file_hash needs to be specified.
+            file_hash (Optional[str], optional): SHA256 hash of the file.
+                Defaults to None. If None, the hash is generated with file_path.
 
         Returns:
             str: Key representing unique ID for a file
         """
-        return (
-            f"{tool_id}|{vector_db}|{embedding}|{x2text}|"
-            f"{chunk_size}|{chunk_overlap}|{file_hash}"
-        )
+        if not file_path and not file_hash:
+            raise ValueError("One of `file_path` or `file_hash` needs to be provided")
+
+        if not file_hash:
+            file_hash = ToolUtils.get_hash_from_file(file_path=file_path)
+
+        # Whole adapter config is used currently even though it contains some keys
+        # which might not be relevant to indexing. This is easier for now than
+        # marking certain keys of the adapter config as necessary.
+        index_key = {
+            "tool_id": tool_id,
+            "file_hash": file_hash,
+            "vector_db_config": ToolAdapter.get_adapter_config(self.tool, vector_db),
+            "embedding_config": ToolAdapter.get_adapter_config(self.tool, embedding),
+            "x2text_config": ToolAdapter.get_adapter_config(self.tool, x2text),
+            "chunk_size": chunk_size,
+            "chunk_overlap": chunk_overlap,
+        }
+        # JSON keys are sorted to ensure that the same key gets hashed even in
+        # case where the fields are reordered.
+        hashed_index_key = ToolUtils.hash_str(json.dumps(index_key, sort_keys=True))
+        return hashed_index_key
diff --git a/src/unstract/sdk/llm.py b/src/unstract/sdk/llm.py
index bff9e536..1da53fc7 100644
--- a/src/unstract/sdk/llm.py
+++ b/src/unstract/sdk/llm.py
@@ -9,7 +9,7 @@
 from unstract.adapters.llm.llm_adapter import LLMAdapter
 
 from unstract.sdk.adapters import ToolAdapter
-from unstract.sdk.constants import LogLevel, ToolSettingsKey
+from unstract.sdk.constants import LogLevel
 from unstract.sdk.exceptions import SdkError
 from unstract.sdk.tool.base import BaseTool
 from unstract.sdk.utils.callback_manager import (
@@ -24,12 +24,8 @@ class ToolLLM:
 
     json_regex = re.compile(r"\{(?:.|\n)*\}")
 
-    def __init__(
-        self,
-        tool: BaseTool,
-        tool_settings: dict[str, str] = {},
-    ):
-        """
+    def __init__(self, tool: BaseTool):
+        """ToolLLM constructor.
 
         Notes:
             - "Azure OpenAI" : Environment variables required
@@ -42,9 +38,7 @@ def __init__(
         self.tool = tool
         self.max_tokens = 1024 * 4
         self.llm_adapters = adapters
-        self.llm_adapter_instance_id = tool_settings.get(
-            ToolSettingsKey.LLM_ADAPTER_ID
-        )
+        self.llm_config_data: Optional[dict[str, Any]] = None
 
     @classmethod
     def run_completion(
@@ -84,47 +78,35 @@ def run_completion(
             time.sleep(5)
         return None
 
-    def get_llm(self, adapter_instance_id: Optional[str] = None) -> LLM:
+    def get_llm(self, adapter_instance_id: str) -> LLM:
         """Returns the LLM object for the tool.
 
         Returns:
             LLM: The LLM object for the tool.
                 (llama_index.llms.base.LLM)
         """
-        adapter_instance_id = (
-            adapter_instance_id
-            if adapter_instance_id
-            else self.llm_adapter_instance_id
-        )
-        # Support for get_llm using adapter_instance_id
-        if adapter_instance_id is not None:
-            try:
-                llm_config_data = ToolAdapter.get_adapter_config(
-                    self.tool, adapter_instance_id
-                )
-                llm_adapter_id = llm_config_data.get(Common.ADAPTER_ID)
-                if llm_adapter_id in self.llm_adapters:
-                    llm_adapter = self.llm_adapters[llm_adapter_id][
-                        Common.METADATA
-                    ][Common.ADAPTER]
-                    llm_metadata = llm_config_data.get(Common.ADAPTER_METADATA)
-                    llm_adapter_class: LLMAdapter = llm_adapter(llm_metadata)
-                    llm_instance: LLM = llm_adapter_class.get_llm_instance()
-                    return llm_instance
-                else:
-                    raise SdkError(
-                        f"LLM adapter not supported : " f"{llm_adapter_id}"
-                    )
-            except Exception as e:
-                self.tool.stream_log(
-                    log=f"Unable to get llm instance: {e}", level=LogLevel.ERROR
+        try:
+            llm_config_data = ToolAdapter.get_adapter_config(
+                self.tool, adapter_instance_id
+            )
+            llm_adapter_id = llm_config_data.get(Common.ADAPTER_ID)
+            if llm_adapter_id not in self.llm_adapters:
+                raise SdkError(
+                    f"LLM adapter not supported : " f"{llm_adapter_id}"
                 )
-                raise SdkError(f"Error getting llm instance: {e}")
-        else:
-            raise SdkError(
-                f"Adapter_instance_id does not have "
-                f"a valid value: {adapter_instance_id}"
+
+            llm_adapter = self.llm_adapters[llm_adapter_id][
+                Common.METADATA
+            ][Common.ADAPTER]
+            llm_metadata = llm_config_data.get(Common.ADAPTER_METADATA)
+            llm_adapter_class: LLMAdapter = llm_adapter(llm_metadata)
+            llm_instance: LLM = llm_adapter_class.get_llm_instance()
+            return llm_instance
+        except Exception as e:
+            self.tool.stream_log(
+                log=f"Unable to get llm instance: {e}", level=LogLevel.ERROR
             )
+            raise SdkError(f"Error getting llm instance: {e}")
 
     def get_max_tokens(self, reserved_for_output: int = 0) -> int:
         """Returns the maximum number of tokens that can be used for the LLM.
diff --git a/src/unstract/sdk/utils/tool_utils.py b/src/unstract/sdk/utils/tool_utils.py
index a66c8b33..ee70ff5e 100644
--- a/src/unstract/sdk/utils/tool_utils.py
+++ b/src/unstract/sdk/utils/tool_utils.py
@@ -5,8 +5,6 @@
 
 import magic
 
-from unstract.sdk.constants import FileReaderSettings
-
 
 class ToolUtils:
     """Class containing utility methods."""
@@ -38,18 +36,24 @@ def hash_str(string_to_hash: Any, hash_method: str = "sha256") -> str:
             raise ValueError(f"Unsupported hash_method: {hash_method}")
 
     @staticmethod
-    def get_hash_from_file(file_path: str):
-        hashes = []
-        chunk_size = FileReaderSettings.FILE_READER_CHUNK_SIZE
-
-        with open(file_path, "rb") as f:
-            while True:
-                chunk = f.read(chunk_size)
-                if not chunk:
-                    break  # End of file
-                hashes.append(ToolUtils.hash_str(chunk))
-        hash_value = ToolUtils.hash_str("".join(hashes))
-        return hash_value
+    def get_hash_from_file(file_path: str) -> str:
+        """Computes the hash for a file.
+
+        Uses sha256 to compute the file hash through a buffered read.
+
+        Args:
+            file_path (str): Path to file that needs to be hashed
+
+        Returns:
+            str: SHA256 hash of the file
+        """
+        h = sha256()
+        b = bytearray(128 * 1024)
+        mv = memoryview(b)
+        with open(file_path, "rb", buffering=0) as f:
+            while n := f.readinto(mv):
+                h.update(mv[:n])
+        return str(h.hexdigest())
 
     @staticmethod
     def load_json(file_to_load: str) -> dict[str, Any]:
diff --git a/src/unstract/sdk/vector_db.py b/src/unstract/sdk/vector_db.py
index 08c2ce4b..9c015be7 100644
--- a/src/unstract/sdk/vector_db.py
+++ b/src/unstract/sdk/vector_db.py
@@ -10,7 +10,7 @@
 from unstract.adapters.vectordb.constants import VectorDbConstants
 
 from unstract.sdk.adapters import ToolAdapter
-from unstract.sdk.constants import LogLevel, ToolEnv, ToolSettingsKey
+from unstract.sdk.constants import LogLevel, ToolEnv
 from unstract.sdk.exceptions import SdkError
 from unstract.sdk.platform import PlatformHelper
 from unstract.sdk.tool.base import BaseTool
@@ -21,12 +21,9 @@
 class ToolVectorDB:
     """Class to handle VectorDB for Unstract Tools."""
 
-    def __init__(self, tool: BaseTool, tool_settings: dict[str, str] = {}):
+    def __init__(self, tool: BaseTool):
         self.tool = tool
         self.vector_db_adapters = adapters
-        self.vector_db_adapter_instance_id = tool_settings.get(
-            ToolSettingsKey.VECTOR_DB_ADAPTER_ID
-        )
 
     def __get_org_id(self) -> str:
         platform_helper = PlatformHelper(
@@ -45,49 +42,43 @@ def __get_org_id(self) -> str:
     def get_vector_db(
         self, adapter_instance_id: str, embedding_dimension: int
     ) -> Union[BasePydanticVectorStore, VectorStore]:
-        adapter_instance_id = (
-            adapter_instance_id
-            if adapter_instance_id
-            else self.vector_db_adapter_instance_id
-        )
-        if adapter_instance_id is not None:
-            try:
-                vector_db_config = ToolAdapter.get_adapter_config(
-                    self.tool, adapter_instance_id
-                )
-                vector_db_adapter_id = vector_db_config.get(Common.ADAPTER_ID)
-                if vector_db_adapter_id in self.vector_db_adapters:
-                    vector_db_adapter = self.vector_db_adapters[
-                        vector_db_adapter_id
-                    ][Common.METADATA][Common.ADAPTER]
-                    vector_db_metadata = vector_db_config.get(
-                        Common.ADAPTER_METADATA
-                    )
-                    org = self.__get_org_id()
-                    # Adding the collection prefix and embedding type
-                    # to the metadata
-                    vector_db_metadata[VectorDbConstants.VECTOR_DB_NAME] = org
-                    vector_db_metadata[
-                        VectorDbConstants.EMBEDDING_DIMENSION
-                    ] = embedding_dimension
+        """Gets an instance of LlamaIndex's VectorStore.
+
+        Args:
+            adapter_instance_id (str): UUID of the vector DB adapter
+            embedding_dimension (int): Embedding dimension for the vector store
 
-                    vector_db_adapter_class = vector_db_adapter(
-                        vector_db_metadata
-                    )
-                    return vector_db_adapter_class.get_vector_db_instance()
-                else:
-                    raise SdkError(
-                        f"VectorDB adapter not supported : "
-                        f"{vector_db_adapter_id}"
-                    )
-            except Exception as e:
-                self.tool.stream_log(
-                    log=f"Unable to get vector_db {adapter_instance_id}: {e}",
-                    level=LogLevel.ERROR,
+        Returns:
+            Union[BasePydanticVectorStore, VectorStore]: Vector store instance
+        """
+        try:
+            vector_db_config = ToolAdapter.get_adapter_config(
+                self.tool, adapter_instance_id
+            )
+            vector_db_adapter_id = vector_db_config.get(Common.ADAPTER_ID)
+            if vector_db_adapter_id not in self.vector_db_adapters:
+                raise SdkError(
+                    f"VectorDB adapter not supported : "
+                    f"{vector_db_adapter_id}"
                 )
-                raise SdkError(f"Error getting vectorDB instance: {e}")
-        else:
-            raise SdkError(
-                f"Adapter_instance_id does not have "
-                f"a valid value: {adapter_instance_id}"
+
+            vector_db_adapter = self.vector_db_adapters[vector_db_adapter_id][
+                Common.METADATA
+            ][Common.ADAPTER]
+            vector_db_metadata = vector_db_config.get(Common.ADAPTER_METADATA)
+            org = self.__get_org_id()
+            # Adding the collection prefix and embedding type
+            # to the metadata
+            vector_db_metadata[VectorDbConstants.VECTOR_DB_NAME] = org
+            vector_db_metadata[
+                VectorDbConstants.EMBEDDING_DIMENSION
+            ] = embedding_dimension
+
+            vector_db_adapter_class = vector_db_adapter(vector_db_metadata)
+            return vector_db_adapter_class.get_vector_db_instance()
+        except Exception as e:
+            self.tool.stream_log(
+                log=f"Unable to get vector_db {adapter_instance_id}: {e}",
+                level=LogLevel.ERROR,
             )
+            raise SdkError(f"Error getting vectorDB instance: {e}")
diff --git a/src/unstract/sdk/x2txt.py b/src/unstract/sdk/x2txt.py
index d6e37327..3809226f 100644
--- a/src/unstract/sdk/x2txt.py
+++ b/src/unstract/sdk/x2txt.py
@@ -1,5 +1,4 @@
 from abc import ABCMeta
-from typing import Optional
 
 from unstract.adapters.constants import Common
 from unstract.adapters.x2text import adapters
@@ -8,6 +7,7 @@
 
 from unstract.sdk.adapters import ToolAdapter
 from unstract.sdk.constants import LogLevel
+from unstract.sdk.exceptions import SdkError
 from unstract.sdk.tool.base import BaseTool
 
 
@@ -16,7 +16,7 @@ def __init__(self, tool: BaseTool):
         self.tool = tool
         self.x2text_adapters = adapters
 
-    def get_x2text(self, adapter_instance_id: str) -> Optional[X2TextAdapter]:
+    def get_x2text(self, adapter_instance_id: str) -> X2TextAdapter:
         try:
             x2text_config = ToolAdapter.get_adapter_config(
                 self.tool, adapter_instance_id
@@ -49,4 +49,4 @@ def get_x2text(self, adapter_instance_id: str) -> Optional[X2TextAdapter]:
             log=f"Unable to get x2text adapter {adapter_instance_id}: {e}",
             level=LogLevel.ERROR,
         )
-        return None
+        raise SdkError(f"Error getting x2text instance: {e}")
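
Usage note (not part of the patch): the sketch below illustrates the calling convention after these changes, where ToolEmbedding, ToolLLM and ToolVectorDB are constructed with only the tool and every lookup takes an explicit adapter instance UUID, and where ToolIndex.generate_file_id() is an instance method that needs either file_path or file_hash. It is a minimal, assumption-laden example: the function name, tool_id and chunk values are illustrative, and `tool` is assumed to be a live BaseTool wired to the Unstract platform.

from unstract.sdk.embedding import ToolEmbedding
from unstract.sdk.index import ToolIndex
from unstract.sdk.llm import ToolLLM
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.vector_db import ToolVectorDB


def load_adapters_and_generate_doc_id(
    tool: BaseTool,
    embedding_id: str,
    vector_db_id: str,
    x2text_id: str,
    llm_id: str,
    file_path: str,
) -> dict:
    """Illustrative sketch of the 0.24.0 call pattern; names here are hypothetical."""
    # Embedding: the adapter instance UUID is passed per call, not via tool_settings.
    embedd_helper = ToolEmbedding(tool=tool)
    embedding_li = embedd_helper.get_embedding(adapter_instance_id=embedding_id)

    # Vector DB: the embedding dimension is now required up front.
    vector_db_li = ToolVectorDB(tool=tool).get_vector_db(
        adapter_instance_id=vector_db_id,
        embedding_dimension=embedd_helper.get_embedding_length(embedding_li),
    )

    # LLM: same per-call UUID pattern; failures surface as SdkError.
    llm = ToolLLM(tool=tool).get_llm(adapter_instance_id=llm_id)

    # doc_id is now derived from the full adapter configs plus the file hash,
    # hashed via ToolUtils.hash_str over sorted JSON, so one of
    # file_path / file_hash must be supplied.
    doc_id = ToolIndex(tool=tool).generate_file_id(
        tool_id="example-tool-id",  # hypothetical value
        vector_db=vector_db_id,
        embedding=embedding_id,
        x2text=x2text_id,
        chunk_size="1024",  # illustrative sizes
        chunk_overlap="128",
        file_path=file_path,
    )
    return {
        "embedding": embedding_li,
        "vector_db": vector_db_li,
        "llm": llm,
        "doc_id": doc_id,
    }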