Skip to content
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ repos:
rev: 24.2.0
hooks:
- id: black
args: [--config=pyproject.toml, -l 80]
args: [--config=pyproject.toml, -l 88]
language: system
exclude: |
(?x)^(
Expand All @@ -60,7 +60,7 @@ repos:
rev: 7.0.0
hooks:
- id: flake8
args: [--max-line-length=80]
args: [--max-line-length=88]
exclude: |
(?x)^(
.*migrations/.*\.py|
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ lint = [
]

[tool.isort]
line_length = 80
line_length = 88
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
Expand Down
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.23.0"
__version__ = "0.24.0"


def get_sdk_version():
Expand Down
4 changes: 0 additions & 4 deletions src/unstract/sdk/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,3 @@ class ToolSettingsKey:
EMBEDDING_ADAPTER_ID = "embeddingAdapterId"
VECTOR_DB_ADAPTER_ID = "vectorDbAdapterId"
X2TEXT_ADAPTER_ID = "x2TextAdapterId"


class FileReaderSettings:
FILE_READER_CHUNK_SIZE = 8192
53 changes: 21 additions & 32 deletions src/unstract/sdk/embedding.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,49 @@
from typing import Optional

from llama_index.core.embeddings import BaseEmbedding
from unstract.adapters.constants import Common
from unstract.adapters.embedding import adapters

from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel, ToolSettingsKey
from unstract.sdk.constants import LogLevel
from unstract.sdk.exceptions import SdkError
from unstract.sdk.tool.base import BaseTool


class ToolEmbedding:
__TEST_SNIPPET = "Hello, I am Unstract"

def __init__(self, tool: BaseTool, tool_settings: dict[str, str] = {}):
def __init__(self, tool: BaseTool):
self.tool = tool
self.max_tokens = 1024 * 16
self.embedding_adapters = adapters
self.embedding_adapter_instance_id = tool_settings.get(
ToolSettingsKey.EMBEDDING_ADAPTER_ID
)
self.embedding_adapter_id: Optional[str] = None

def get_embedding(
self, adapter_instance_id: Optional[str] = None
) -> BaseEmbedding:
adapter_instance_id = (
adapter_instance_id
if adapter_instance_id
else self.embedding_adapter_instance_id
)
if not adapter_instance_id:
raise SdkError(
f"Adapter_instance_id does not have "
f"a valid value: {adapter_instance_id}"
)
def get_embedding(self, adapter_instance_id: str) -> BaseEmbedding:
"""Gets an instance of LlamaIndex's embedding object.

Args:
adapter_instance_id (str): UUID of the embedding adapter

Returns:
BaseEmbedding: Embedding instance
"""
try:
embedding_config_data = ToolAdapter.get_adapter_config(
self.tool, adapter_instance_id
)
embedding_adapter_id = embedding_config_data.get(Common.ADAPTER_ID)
self.embedding_adapter_id = embedding_adapter_id
if embedding_adapter_id in self.embedding_adapters:
embedding_adapter = self.embedding_adapters[
embedding_adapter_id
][Common.METADATA][Common.ADAPTER]
embedding_metadata = embedding_config_data.get(
Common.ADAPTER_METADATA
)
embedding_adapter_class = embedding_adapter(embedding_metadata)
return embedding_adapter_class.get_embedding_instance()
else:
if embedding_adapter_id not in self.embedding_adapters:
raise SdkError(
f"Embedding adapter not supported : "
f"{embedding_adapter_id}"
)

embedding_adapter = self.embedding_adapters[embedding_adapter_id][
Common.METADATA
][Common.ADAPTER]
embedding_metadata = embedding_config_data.get(
Common.ADAPTER_METADATA
)
embedding_adapter_class = embedding_adapter(embedding_metadata)
return embedding_adapter_class.get_embedding_instance()
except Exception as e:
self.tool.stream_log(
log=f"Error getting embedding: {e}", level=LogLevel.ERROR
Expand Down
114 changes: 50 additions & 64 deletions src/unstract/sdk/index.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from typing import Optional

from llama_index.core import Document
Expand All @@ -14,14 +15,13 @@
from unstract.adapters.exceptions import AdapterError
from unstract.adapters.x2text.x2text_adapter import X2TextAdapter

from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel, ToolEnv
from unstract.sdk.embedding import ToolEmbedding
from unstract.sdk.exceptions import IndexingError, SdkError
from unstract.sdk.tool.base import BaseTool
from unstract.sdk.utils import ToolUtils
from unstract.sdk.utils.callback_manager import (
CallbackManager as UNCallbackManager,
)
from unstract.sdk.utils.callback_manager import CallbackManager as UNCallbackManager
from unstract.sdk.vector_db import ToolVectorDB
from unstract.sdk.x2txt import X2Text

Expand All @@ -31,18 +31,9 @@ def __init__(self, tool: BaseTool):
# TODO: Inherit from StreamMixin and avoid using BaseTool
self.tool = tool

def get_text_from_index(
self, embedding_type: str, vector_db: str, doc_id: str
):
def get_text_from_index(self, embedding_type: str, vector_db: str, doc_id: str):
embedd_helper = ToolEmbedding(tool=self.tool)
embedding_li = embedd_helper.get_embedding(
adapter_instance_id=embedding_type
)
if embedding_li is None:
self.tool.stream_log(
f"Error loading {embedding_type}", level=LogLevel.ERROR
)
raise SdkError(f"Error loading {embedding_type}")
embedding_li = embedd_helper.get_embedding(adapter_instance_id=embedding_type)
embedding_dimension = embedd_helper.get_embedding_length(embedding_li)

vdb_helper = ToolVectorDB(
Expand All @@ -53,12 +44,6 @@ def get_text_from_index(
embedding_dimension=embedding_dimension,
)

if vector_db_li is None:
self.tool.stream_log(
f"Error loading {vector_db}", level=LogLevel.ERROR
)
raise SdkError(f"Error loading {vector_db}")

try:
self.tool.stream_log(f">>> Querying {vector_db}...")
self.tool.stream_log(f">>> {doc_id}")
Expand Down Expand Up @@ -149,48 +134,33 @@ def index_file(
Returns:
str: A unique ID for the file and indexing arguments combination
"""
# Make file content hash if not available
if not file_hash:
file_hash = ToolUtils.get_hash_from_file(file_path=file_path)

doc_id = ToolIndex.generate_file_id(
doc_id = self.generate_file_id(
tool_id=tool_id,
file_hash=file_hash,
vector_db=vector_db,
embedding=embedding_type,
x2text=x2text_adapter,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
chunk_size=str(chunk_size),
chunk_overlap=str(chunk_overlap),
file_path=file_path,
file_hash=file_hash,
)

self.tool.stream_log(f"Checking if doc_id {doc_id} exists")

vdb_helper = ToolVectorDB(
tool=self.tool,
)

# Get embedding instance
embedd_helper = ToolEmbedding(tool=self.tool)
embedding_li = embedd_helper.get_embedding(adapter_instance_id=embedding_type)
embedding_dimension = embedd_helper.get_embedding_length(embedding_li)

embedding_li = embedd_helper.get_embedding(
adapter_instance_id=embedding_type
# Get vectorDB instance
vdb_helper = ToolVectorDB(
tool=self.tool,
)
if embedding_li is None:
self.tool.stream_log(
f"Error loading {embedding_type}", level=LogLevel.ERROR
)
raise SdkError(f"Error loading {embedding_type}")

embedding_dimension = embedd_helper.get_embedding_length(embedding_li)
vector_db_li = vdb_helper.get_vector_db(
adapter_instance_id=vector_db,
embedding_dimension=embedding_dimension,
)
if vector_db_li is None:
self.tool.stream_log(
f"Error loading {vector_db}", level=LogLevel.ERROR
)
raise SdkError(f"Error loading {vector_db}")

# Checking if document is already indexed against doc_id
doc_id_eq_filter = MetadataFilter.from_dict(
{"key": "doc_id", "operator": FilterOperator.EQ, "value": doc_id}
)
Expand Down Expand Up @@ -275,26 +245,20 @@ def index_file(
parser = SimpleNodeParser.from_defaults(
chunk_size=len(documents[0].text) + 10, chunk_overlap=0
)
nodes = parser.get_nodes_from_documents(
documents, show_progress=True
)
nodes = parser.get_nodes_from_documents(documents, show_progress=True)
node = nodes[0]
node.embedding = embedding_li.get_query_embedding(" ")
vector_db_li.add(nodes=[node])
self.tool.stream_log("Added node to vector db")
else:
storage_context = StorageContext.from_defaults(
vector_store=vector_db_li
)
storage_context = StorageContext.from_defaults(vector_store=vector_db_li)
parser = SimpleNodeParser.from_defaults(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)

# Set callback_manager to collect Usage stats
callback_manager = UNCallbackManager.set_callback_manager(
platform_api_key=self.tool.get_env_or_die(
ToolEnv.PLATFORM_API_KEY
),
platform_api_key=self.tool.get_env_or_die(ToolEnv.PLATFORM_API_KEY),
embedding=embedding_li,
)

Expand All @@ -319,31 +283,53 @@ def index_file(
self.tool.stream_log("File has been indexed successfully")
return doc_id

@staticmethod
def generate_file_id(
self,
tool_id: str,
file_hash: str,
vector_db: str,
embedding: str,
x2text: str,
chunk_size: str,
chunk_overlap: str,
file_path: Optional[str] = None,
file_hash: Optional[str] = None,
) -> str:
"""Generates a unique ID useful for identifying files during indexing.

Args:
tool_id (str): Unique ID of the tool developed / exported
file_hash (str): Hash of the file contents
tool_id (str): Unique ID of the tool or workflow
vector_db (str): UUID of the vector DB adapter
embedding (str): UUID of the embedding adapter
x2text (str): UUID of the X2Text adapter
chunk_size (str): Chunk size for indexing
chunk_overlap (str): Chunk overlap for indexing
file_path (Optional[str]): Path to the file that needs to be indexed.
Defaults to None. One of file_path or file_hash needs to be specified.
file_hash (Optional[str], optional): SHA256 hash of the file.
Defaults to None. If None, the hash is generated with file_path.

Returns:
str: Key representing unique ID for a file
"""
return (
f"{tool_id}|{vector_db}|{embedding}|{x2text}|"
f"{chunk_size}|{chunk_overlap}|{file_hash}"
)
if not file_path and not file_hash:
raise ValueError("One of `file_path` or `file_hash` need to be provided")

if not file_hash:
file_hash = ToolUtils.get_hash_from_file(file_path=file_path)

# Whole adapter config is used currently even though it contains some keys
# which might not be relevant to indexing. This is easier for now than
# marking certain keys of the adapter config as necessary.
index_key = {
"tool_id": tool_id,
"file_hash": file_hash,
"vector_db_config": ToolAdapter.get_adapter_config(self.tool, vector_db),
"embedding_config": ToolAdapter.get_adapter_config(self.tool, embedding),
"x2text_config": ToolAdapter.get_adapter_config(self.tool, x2text),
"chunk_size": chunk_size,
"chunk_overlap": chunk_overlap,
}
# JSON keys are sorted to ensure that the same key gets hashed even in
# case where the fields are reordered.
hashed_index_key = ToolUtils.hash_str(json.dumps(index_key, sort_keys=True))
return hashed_index_key
Loading