diff --git a/pyproject.toml b/pyproject.toml index c24b8ec..cb7fb59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,17 +157,15 @@ select = [ fixable = ["ALL"] ignore = [ "D100", # Missing docstring in public module + "D101", # Missing docstring in public class + "D102", # Missing docstring in public method "D104", # Missing docstring in public package + "D103", # Missing docstring in public function + "D106", # Missing docstring in public nested class "ANN101", # Missing type annotation for self "ANN102", # Missing type annotation for cls ] -[tool.ruff.lint.per-file-ignores] -"test_*.py" = [ - "ANN201", # Missing return type annotation for public function - "D103", # Missing docstring in public function -] - [tool.ruff.format] quote-style = "double" indent-style = "space" diff --git a/src/unstract/sdk/adapter.py b/src/unstract/sdk/adapter.py index 5d3d381..5cb5a14 100644 --- a/src/unstract/sdk/adapter.py +++ b/src/unstract/sdk/adapter.py @@ -1,9 +1,8 @@ import json -from typing import Any, Optional +from typing import Any import requests from requests.exceptions import ConnectionError, HTTPError - from unstract.sdk.adapters.utils import AdapterUtils from unstract.sdk.constants import AdapterKeys, LogLevel, ToolEnv from unstract.sdk.exceptions import SdkError @@ -25,8 +24,7 @@ def __init__( platform_host: str, platform_port: str, ) -> None: - """ - Args: + """Args: tool (AbstractTool): Instance of AbstractTool platform_host (str): Host of platform service platform_port (str): Port of platform service @@ -89,7 +87,7 @@ def _get_adapter_configuration( @staticmethod def get_adapter_config( tool: BaseTool, adapter_instance_id: str - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: """Get adapter spec by the help of unstract DB tool. This method first checks if the adapter_instance_id matches diff --git a/src/unstract/sdk/adapters/adapterkit.py b/src/unstract/sdk/adapters/adapterkit.py index ae948e2..9c08bd2 100644 --- a/src/unstract/sdk/adapters/adapterkit.py +++ b/src/unstract/sdk/adapters/adapterkit.py @@ -2,7 +2,6 @@ from typing import Any from singleton_decorator import singleton - from unstract.sdk.adapters import AdapterDict from unstract.sdk.adapters.base import Adapter from unstract.sdk.adapters.constants import Common @@ -34,16 +33,14 @@ def adapters(self) -> AdapterDict: def get_adapter_class_by_adapter_id(self, adapter_id: str) -> Adapter: if adapter_id in self._adapters: - adapter_class: Adapter = self._adapters[adapter_id][ - Common.METADATA - ][Common.ADAPTER] + adapter_class: Adapter = self._adapters[adapter_id][Common.METADATA][ + Common.ADAPTER + ] return adapter_class else: raise RuntimeError(f"Couldn't obtain adapter for {adapter_id}") - def get_adapter_by_id( - self, adapter_id: str, *args: Any, **kwargs: Any - ) -> Adapter: + def get_adapter_by_id(self, adapter_id: str, *args: Any, **kwargs: Any) -> Adapter: """Instantiates and returns a adapter. Args: @@ -55,17 +52,13 @@ def get_adapter_by_id( Returns: Adapter: Concrete impl of the `Adapter` base """ - adapter_class: Adapter = self.get_adapter_class_by_adapter_id( - adapter_id - ) + adapter_class: Adapter = self.get_adapter_class_by_adapter_id(adapter_id) return adapter_class(*args, **kwargs) def get_adapters_list(self) -> list[dict[str, Any]]: adapters = [] for adapter_id, adapter_registry_metadata in self._adapters.items(): - m: Adapter = adapter_registry_metadata[Common.METADATA][ - Common.ADAPTER - ] + m: Adapter = adapter_registry_metadata[Common.METADATA][Common.ADAPTER] _id = m.get_id() name = m.get_name() adapter_type = m.get_adapter_type().name diff --git a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py index d533fb9..872cb39 100644 --- a/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py +++ b/src/unstract/sdk/adapters/embedding/azure_open_ai/src/azure_open_ai.py @@ -4,7 +4,6 @@ import httpx from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.embedding.helper import EmbeddingHelper from unstract.sdk.adapters.exceptions import AdapterError @@ -71,5 +70,3 @@ def get_embedding_instance(self) -> BaseEmbedding: return embedding except Exception as e: raise AdapterError(str(e)) - - diff --git a/src/unstract/sdk/adapters/embedding/bedrock/src/bedrock.py b/src/unstract/sdk/adapters/embedding/bedrock/src/bedrock.py index f2633e1..2b2cdb0 100644 --- a/src/unstract/sdk/adapters/embedding/bedrock/src/bedrock.py +++ b/src/unstract/sdk/adapters/embedding/bedrock/src/bedrock.py @@ -3,11 +3,11 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.bedrock import BedrockEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.embedding.helper import EmbeddingHelper from unstract.sdk.adapters.exceptions import AdapterError + class Constants: MODEL = "model" TIMEOUT = "timeout" @@ -18,6 +18,7 @@ class Constants: DEFAULT_TIMEOUT = 240 DEFAULT_MAX_RETRIES = 3 + class Bedrock(EmbeddingAdapter): def __init__(self, settings: dict[str, Any]): super().__init__("Bedrock") @@ -66,5 +67,3 @@ def get_embedding_instance(self) -> BaseEmbedding: return embedding except Exception as e: raise AdapterError(str(e)) - - \ No newline at end of file diff --git a/src/unstract/sdk/adapters/embedding/embedding_adapter.py b/src/unstract/sdk/adapters/embedding/embedding_adapter.py index 8283f2d..98e3cf2 100644 --- a/src/unstract/sdk/adapters/embedding/embedding_adapter.py +++ b/src/unstract/sdk/adapters/embedding/embedding_adapter.py @@ -3,11 +3,10 @@ from llama_index.core import MockEmbedding from llama_index.core.embeddings import BaseEmbedding - from unstract.sdk.adapters.base import Adapter +from unstract.sdk.adapters.embedding.helper import EmbeddingHelper from unstract.sdk.adapters.enums import AdapterTypes -from unstract.sdk.adapters.embedding.helper import EmbeddingHelper class EmbeddingAdapter(Adapter, ABC): def __init__(self, name: str): @@ -47,8 +46,8 @@ def get_embedding_instance(self, embed_config: dict[str, Any]) -> BaseEmbedding: Raises exceptions for any error """ return MockEmbedding(embed_dim=1) - + def test_connection(self) -> bool: embedding = self.get_embedding_instance() test_result: bool = EmbeddingHelper.test_embedding_instance(embedding) - return test_result \ No newline at end of file + return test_result diff --git a/src/unstract/sdk/adapters/embedding/helper.py b/src/unstract/sdk/adapters/embedding/helper.py index 256dedc..3596df4 100644 --- a/src/unstract/sdk/adapters/embedding/helper.py +++ b/src/unstract/sdk/adapters/embedding/helper.py @@ -1,8 +1,7 @@ import logging -from typing import Any, Optional +from typing import Any from llama_index.core.embeddings import BaseEmbedding - from unstract.sdk.adapters.exceptions import AdapterError logger = logging.getLogger(__name__) @@ -28,7 +27,7 @@ def get_embedding_batch_size(config: dict[str, Any]) -> int: return embedding_batch_size @staticmethod - def test_embedding_instance(embedding: Optional[BaseEmbedding]) -> bool: + def test_embedding_instance(embedding: BaseEmbedding | None) -> bool: try: if embedding is None: return False diff --git a/src/unstract/sdk/adapters/embedding/hugging_face/src/hugging_face.py b/src/unstract/sdk/adapters/embedding/hugging_face/src/hugging_face.py index b9e01fe..a0e0b2b 100644 --- a/src/unstract/sdk/adapters/embedding/hugging_face/src/hugging_face.py +++ b/src/unstract/sdk/adapters/embedding/hugging_face/src/hugging_face.py @@ -1,9 +1,8 @@ import os -from typing import Any, Optional +from typing import Any from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.huggingface import HuggingFaceEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.embedding.helper import EmbeddingHelper from unstract.sdk.adapters.exceptions import AdapterError @@ -45,7 +44,7 @@ def get_embedding_instance(self) -> BaseEmbedding: embedding_batch_size = EmbeddingHelper.get_embedding_batch_size( config=self.config ) - max_length: Optional[int] = ( + max_length: int | None = ( int(self.config.get(Constants.MAX_LENGTH, 0)) if self.config.get(Constants.MAX_LENGTH) else None @@ -61,5 +60,3 @@ def get_embedding_instance(self) -> BaseEmbedding: return embedding except Exception as e: raise AdapterError(str(e)) - - diff --git a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py index fd1499d..68b5b8a 100644 --- a/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py +++ b/src/unstract/sdk/adapters/embedding/ollama/src/ollama.py @@ -3,7 +3,6 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.ollama import OllamaEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.embedding.helper import EmbeddingHelper from unstract.sdk.adapters.exceptions import AdapterError @@ -55,5 +54,3 @@ def get_embedding_instance(self) -> BaseEmbedding: return embedding except Exception as e: raise AdapterError(str(e)) - - diff --git a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py index 0ca1ce7..781e849 100644 --- a/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py @@ -4,7 +4,6 @@ import httpx from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.openai import OpenAIEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.exceptions import AdapterError @@ -66,5 +65,3 @@ def get_embedding_instance(self) -> BaseEmbedding: return embedding except Exception as e: raise AdapterError(str(e)) - - diff --git a/src/unstract/sdk/adapters/embedding/palm/src/palm.py b/src/unstract/sdk/adapters/embedding/palm/src/palm.py index 25cb411..8e31e51 100644 --- a/src/unstract/sdk/adapters/embedding/palm/src/palm.py +++ b/src/unstract/sdk/adapters/embedding/palm/src/palm.py @@ -3,7 +3,6 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.google import GooglePaLMEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.embedding.helper import EmbeddingHelper from unstract.sdk.adapters.exceptions import AdapterError @@ -55,5 +54,3 @@ def get_embedding_instance(self) -> BaseEmbedding: return embedding except Exception as e: raise AdapterError(str(e)) - - diff --git a/src/unstract/sdk/adapters/embedding/qdrant_fast_embed/src/qdrant_fast_embed.py b/src/unstract/sdk/adapters/embedding/qdrant_fast_embed/src/qdrant_fast_embed.py index 94bb4d2..0922256 100644 --- a/src/unstract/sdk/adapters/embedding/qdrant_fast_embed/src/qdrant_fast_embed.py +++ b/src/unstract/sdk/adapters/embedding/qdrant_fast_embed/src/qdrant_fast_embed.py @@ -3,7 +3,6 @@ from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.fastembed import FastEmbedEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.exceptions import AdapterError @@ -44,5 +43,3 @@ def get_embedding_instance(self) -> BaseEmbedding: return embedding except Exception as e: raise AdapterError(str(e)) - - diff --git a/src/unstract/sdk/adapters/embedding/vertex_ai/src/vertex_ai.py b/src/unstract/sdk/adapters/embedding/vertex_ai/src/vertex_ai.py index 395329e..a8840da 100644 --- a/src/unstract/sdk/adapters/embedding/vertex_ai/src/vertex_ai.py +++ b/src/unstract/sdk/adapters/embedding/vertex_ai/src/vertex_ai.py @@ -4,21 +4,21 @@ from google.auth.transport import requests as google_requests from google.oauth2.service_account import Credentials - from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.vertex import VertexTextEmbedding - from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter from unstract.sdk.adapters.embedding.helper import EmbeddingHelper from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.exceptions import EmbeddingError + class Constants: MODEL = "model" PROJECT = "project" JSON_CREDENTIALS = "json_credentials" EMBED_MODE = "embed_mode" + class VertexAIEmbedding(EmbeddingAdapter): def __init__(self, settings: dict[str, Any]): super().__init__("Bedrock") @@ -45,7 +45,7 @@ def get_provider() -> str: @staticmethod def get_icon() -> str: return "/icons/adapter-icons/VertexAI.png" - + def get_embedding_instance(self) -> BaseEmbedding: try: embedding_batch_size = EmbeddingHelper.get_embedding_batch_size( @@ -69,11 +69,9 @@ def get_embedding_instance(self) -> BaseEmbedding: ) return embedding except json.JSONDecodeError: - raise EmbeddingError( - "Credentials is not a valid service account JSON, " - "please provide a valid JSON." - ) + raise EmbeddingError( + "Credentials is not a valid service account JSON, " + "please provide a valid JSON." + ) except Exception as e: raise AdapterError(str(e)) - - \ No newline at end of file diff --git a/src/unstract/sdk/adapters/llm/anthropic/src/anthropic.py b/src/unstract/sdk/adapters/llm/anthropic/src/anthropic.py index 2414d77..1890709 100644 --- a/src/unstract/sdk/adapters/llm/anthropic/src/anthropic.py +++ b/src/unstract/sdk/adapters/llm/anthropic/src/anthropic.py @@ -5,7 +5,6 @@ from llama_index.core.llms import LLM from llama_index.llms.anthropic import Anthropic from llama_index.llms.anthropic.base import DEFAULT_ANTHROPIC_MAX_TOKENS - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -23,6 +22,7 @@ class Constants: ENABLE_THINKING = "enable_thinking" BUDGET_TOKENS = "budget_tokens" + class AnthropicLLM(LLMAdapter): def __init__(self, settings: dict[str, Any]): super().__init__("Anthropic") @@ -63,7 +63,7 @@ def get_llm_instance(self) -> LLM: budget_tokens = self.config.get(Constants.BUDGET_TOKENS) thinking_dict = {"type": "enabled", "budget_tokens": budget_tokens} temperature = 1 - + try: llm: LLM = Anthropic( model=str(self.config.get(Constants.MODEL)), @@ -76,7 +76,7 @@ def get_llm_instance(self) -> LLM: ), temperature=temperature, max_tokens=max_tokens, - thinking_dict=thinking_dict + thinking_dict=thinking_dict, ) return llm except Exception as e: diff --git a/src/unstract/sdk/adapters/llm/bedrock/src/bedrock.py b/src/unstract/sdk/adapters/llm/bedrock/src/bedrock.py index d46f854..067e6bb 100644 --- a/src/unstract/sdk/adapters/llm/bedrock/src/bedrock.py +++ b/src/unstract/sdk/adapters/llm/bedrock/src/bedrock.py @@ -1,9 +1,8 @@ import os -from typing import Any, Optional +from typing import Any from llama_index.core.llms import LLM from llama_index.llms.bedrock import Bedrock - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -47,11 +46,11 @@ def get_provider() -> str: @staticmethod def get_icon() -> str: - return "/icons/adapter-icons/Bedrock.png" + return "/icons/adapter-icons/Bedrock.png" def get_llm_instance(self) -> LLM: try: - context_size: Optional[int] = ( + context_size: int | None = ( int(self.config.get(Constants.CONTEXT_SIZE, 0)) if self.config.get(Constants.CONTEXT_SIZE) else None @@ -76,4 +75,4 @@ def get_llm_instance(self) -> LLM: ) return llm except Exception as e: - raise AdapterError(str(e)) \ No newline at end of file + raise AdapterError(str(e)) diff --git a/src/unstract/sdk/adapters/llm/llm_adapter.py b/src/unstract/sdk/adapters/llm/llm_adapter.py index 0c94f61..4159579 100644 --- a/src/unstract/sdk/adapters/llm/llm_adapter.py +++ b/src/unstract/sdk/adapters/llm/llm_adapter.py @@ -1,11 +1,9 @@ import logging import re from abc import ABC, abstractmethod -from typing import Optional from llama_index.core.llms import LLM, MockLLM from llama_index.llms.openai.utils import O1_MODELS - from unstract.sdk.adapters.base import Adapter from unstract.sdk.adapters.enums import AdapterTypes from unstract.sdk.adapters.exceptions import LLMError @@ -67,7 +65,7 @@ def get_llm_instance(self) -> LLM: return MockLLM() @staticmethod - def _test_llm_instance(llm: Optional[LLM]) -> bool: + def _test_llm_instance(llm: LLM | None) -> bool: if llm is None: raise LLMError( message="Unable to connect to LLM, please recheck the configuration", @@ -75,16 +73,13 @@ def _test_llm_instance(llm: Optional[LLM]) -> bool: ) # Get completion kwargs based on model capabilities completion_kwargs = {} - if hasattr(llm, 'model') and getattr(llm, 'model') not in O1_MODELS: - completion_kwargs['temperature'] = 0.003 - - if hasattr(llm, 'thinking_dict') and getattr(llm, 'thinking_dict') is not None: - completion_kwargs['temperature'] = 1 - - response = llm.complete( - "The capital of Tamilnadu is ", - **completion_kwargs - ) + if hasattr(llm, "model") and llm.model not in O1_MODELS: + completion_kwargs["temperature"] = 0.003 + + if hasattr(llm, "thinking_dict") and llm.thinking_dict is not None: + completion_kwargs["temperature"] = 1 + + response = llm.complete("The capital of Tamilnadu is ", **completion_kwargs) response_lower_case: str = response.text.lower() find_match = re.search("chennai", response_lower_case) if find_match: diff --git a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py index 81692ab..0e54d6b 100644 --- a/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py +++ b/src/unstract/sdk/adapters/llm/open_ai/src/open_ai.py @@ -5,12 +5,12 @@ from llama_index.llms.openai import OpenAI from llama_index.llms.openai.utils import O1_MODELS from openai import APIError as OpenAIAPIError - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.constants import LLMKeys from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter from unstract.sdk.exceptions import LLMError + class Constants: MODEL = "model" API_KEY = "api_key" @@ -23,7 +23,6 @@ class Constants: class OpenAILLM(LLMAdapter): - def __init__(self, settings: dict[str, Any]): super().__init__("OpenAI") self.config = settings @@ -61,9 +60,13 @@ def get_llm_instance(self) -> LLM: "api_key": str(self.config.get(Constants.API_KEY)), "api_base": str(self.config.get(Constants.API_BASE)), "api_version": str(self.config.get(Constants.API_VERSION)), - "max_retries": int(self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES)), + "max_retries": int( + self.config.get(Constants.MAX_RETRIES, LLMKeys.DEFAULT_MAX_RETRIES) + ), "api_type": "openai", - "timeout": float(self.config.get(Constants.TIMEOUT, LLMKeys.DEFAULT_TIMEOUT)), + "timeout": float( + self.config.get(Constants.TIMEOUT, LLMKeys.DEFAULT_TIMEOUT) + ), "max_tokens": max_tokens, } diff --git a/src/unstract/sdk/adapters/llm/palm/src/palm.py b/src/unstract/sdk/adapters/llm/palm/src/palm.py index c0bd310..426908b 100644 --- a/src/unstract/sdk/adapters/llm/palm/src/palm.py +++ b/src/unstract/sdk/adapters/llm/palm/src/palm.py @@ -1,10 +1,9 @@ import os -from typing import Any, Optional +from typing import Any from google.api_core.exceptions import GoogleAPICallError from llama_index.core.llms import LLM from llama_index.llms.palm import PaLM - from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter from unstract.sdk.exceptions import LLMError @@ -44,11 +43,9 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/PaLM.png" - - def get_llm_instance(self) -> LLM: try: - num_output: Optional[int] = ( + num_output: int | None = ( int(self.config.get(Constants.NUM_OUTPUT, Constants.DEFAULT_MAX_TOKENS)) if self.config.get(Constants.NUM_OUTPUT) is not None else None diff --git a/src/unstract/sdk/adapters/llm/replicate/src/replicate.py b/src/unstract/sdk/adapters/llm/replicate/src/replicate.py index a522be9..5f9ac42 100644 --- a/src/unstract/sdk/adapters/llm/replicate/src/replicate.py +++ b/src/unstract/sdk/adapters/llm/replicate/src/replicate.py @@ -3,7 +3,6 @@ from llama_index.core.llms import LLM from llama_index.llms.replicate import Replicate - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter @@ -40,8 +39,6 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/Replicate.png" - - @staticmethod def can_write() -> bool: return True diff --git a/src/unstract/sdk/adapters/llm/vertex_ai/src/vertex_ai.py b/src/unstract/sdk/adapters/llm/vertex_ai/src/vertex_ai.py index 6f8983f..9f238e5 100644 --- a/src/unstract/sdk/adapters/llm/vertex_ai/src/vertex_ai.py +++ b/src/unstract/sdk/adapters/llm/vertex_ai/src/vertex_ai.py @@ -1,22 +1,21 @@ import json import logging import os -from typing import Any, Optional +from typing import Any from google.auth.transport import requests as google_requests from google.oauth2.service_account import Credentials from llama_index.core.llms import LLM from llama_index.llms.vertex import Vertex +from unstract.sdk.adapters.exceptions import LLMError +from unstract.sdk.adapters.llm.constants import LLMKeys +from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter from vertexai.generative_models import Candidate, FinishReason, ResponseValidationError from vertexai.generative_models._generative_models import ( HarmBlockThreshold, HarmCategory, ) -from unstract.sdk.adapters.exceptions import LLMError -from unstract.sdk.adapters.llm.constants import LLMKeys -from unstract.sdk.adapters.llm.llm_adapter import LLMAdapter - logger = logging.getLogger(__name__) @@ -76,8 +75,6 @@ def get_provider() -> str: def get_icon() -> str: return "/icons/adapter-icons/VertexAI.png" - - def get_llm_instance(self) -> LLM: input_credentials = self.config.get(Constants.JSON_CREDENTIALS, "{}") try: @@ -112,9 +109,9 @@ def get_llm_instance(self) -> LLM: safety_settings_default_config, ) - vertex_safety_settings: dict[ - HarmCategory, HarmBlockThreshold - ] = self._get_vertex_safety_settings(safety_settings_user_config) + vertex_safety_settings: dict[HarmCategory, HarmBlockThreshold] = ( + self._get_vertex_safety_settings(safety_settings_user_config) + ) llm: LLM = Vertex( project=str(self.config.get(Constants.PROJECT)), @@ -131,55 +128,55 @@ def _get_vertex_safety_settings( self, safety_settings_user_config: dict[str, str] ) -> dict[HarmCategory, HarmBlockThreshold]: vertex_safety_settings: dict[HarmCategory, HarmBlockThreshold] = dict() - vertex_safety_settings[ - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT - ] = UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ - ( - safety_settings_user_config.get( - SafetySettingsConstants.DANGEROUS_CONTENT, - Constants.BLOCK_ONLY_HIGH, + vertex_safety_settings[HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT] = ( + UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ + ( + safety_settings_user_config.get( + SafetySettingsConstants.DANGEROUS_CONTENT, + Constants.BLOCK_ONLY_HIGH, + ) ) - ) - ] - vertex_safety_settings[ - HarmCategory.HARM_CATEGORY_HATE_SPEECH - ] = UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ - ( - safety_settings_user_config.get( - SafetySettingsConstants.HATE_SPEECH, - Constants.BLOCK_ONLY_HIGH, + ] + ) + vertex_safety_settings[HarmCategory.HARM_CATEGORY_HATE_SPEECH] = ( + UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ + ( + safety_settings_user_config.get( + SafetySettingsConstants.HATE_SPEECH, + Constants.BLOCK_ONLY_HIGH, + ) ) - ) - ] - vertex_safety_settings[ - HarmCategory.HARM_CATEGORY_HARASSMENT - ] = UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ - ( - safety_settings_user_config.get( - SafetySettingsConstants.HARASSMENT, - Constants.BLOCK_ONLY_HIGH, + ] + ) + vertex_safety_settings[HarmCategory.HARM_CATEGORY_HARASSMENT] = ( + UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ + ( + safety_settings_user_config.get( + SafetySettingsConstants.HARASSMENT, + Constants.BLOCK_ONLY_HIGH, + ) ) - ) - ] - vertex_safety_settings[ - HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT - ] = UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ - ( - safety_settings_user_config.get( - SafetySettingsConstants.SEXUAL_CONTENT, - Constants.BLOCK_ONLY_HIGH, + ] + ) + vertex_safety_settings[HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT] = ( + UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ + ( + safety_settings_user_config.get( + SafetySettingsConstants.SEXUAL_CONTENT, + Constants.BLOCK_ONLY_HIGH, + ) ) - ) - ] - vertex_safety_settings[ - HarmCategory.HARM_CATEGORY_UNSPECIFIED - ] = UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ - ( - safety_settings_user_config.get( - SafetySettingsConstants.OTHER, Constants.BLOCK_ONLY_HIGH + ] + ) + vertex_safety_settings[HarmCategory.HARM_CATEGORY_UNSPECIFIED] = ( + UNSTRACT_VERTEX_SAFETY_THRESHOLD_MAPPING[ + ( + safety_settings_user_config.get( + SafetySettingsConstants.OTHER, Constants.BLOCK_ONLY_HIGH + ) ) - ) - ] + ] + ) return vertex_safety_settings @staticmethod @@ -200,7 +197,7 @@ def parse_llm_err(e: ResponseValidationError) -> LLMError: "since its a completion call and not chat." ) resp = e.responses[0] - candidates: list["Candidate"] = resp.candidates + candidates: list[Candidate] = resp.candidates if not candidates: msg = str(resp.prompt_feedback) reason_messages = { @@ -241,7 +238,7 @@ def parse_llm_err(e: ResponseValidationError) -> LLMError: } err_list = [] - status_code: Optional[int] = None + status_code: int | None = None for candidate in candidates: reason: FinishReason = candidate.finish_reason diff --git a/src/unstract/sdk/adapters/ocr/google_document_ai/src/google_document_ai.py b/src/unstract/sdk/adapters/ocr/google_document_ai/src/google_document_ai.py index 2cfe039..3d76076 100644 --- a/src/unstract/sdk/adapters/ocr/google_document_ai/src/google_document_ai.py +++ b/src/unstract/sdk/adapters/ocr/google_document_ai/src/google_document_ai.py @@ -2,13 +2,12 @@ import json import logging import os -from typing import Any, Optional +from typing import Any import requests from filetype import filetype from google.auth.transport import requests as google_requests from google.oauth2.service_account import Credentials - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.ocr.constants import FileType from unstract.sdk.adapters.ocr.ocr_adapter import OCRAdapter @@ -59,8 +58,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/GoogleDocumentAI.png" - - """ Construct the request body to be sent to Google AI Document server """ def _get_request_body( @@ -113,7 +110,7 @@ def _get_input_file_type_mime( def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> str: try: diff --git a/src/unstract/sdk/adapters/ocr/ocr_adapter.py b/src/unstract/sdk/adapters/ocr/ocr_adapter.py index e685253..6cd4d96 100644 --- a/src/unstract/sdk/adapters/ocr/ocr_adapter.py +++ b/src/unstract/sdk/adapters/ocr/ocr_adapter.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Optional +from typing import Any from unstract.sdk.adapters.base import Adapter from unstract.sdk.adapters.enums import AdapterTypes @@ -30,9 +30,7 @@ def get_icon() -> str: def get_adapter_type() -> AdapterTypes: return AdapterTypes.OCR - def process( - self, input_file_path: str, output_file_path: Optional[str] = None - ) -> str: + def process(self, input_file_path: str, output_file_path: str | None = None) -> str: # Overriding methods will contain actual implementation return "" diff --git a/src/unstract/sdk/adapters/vectordb/__init__.py b/src/unstract/sdk/adapters/vectordb/__init__.py index e0252fe..531b6d1 100644 --- a/src/unstract/sdk/adapters/vectordb/__init__.py +++ b/src/unstract/sdk/adapters/vectordb/__init__.py @@ -1,6 +1,5 @@ from unstract.sdk.adapters import AdapterDict from unstract.sdk.adapters.vectordb.register import VectorDBRegistry - adapters: AdapterDict = {} VectorDBRegistry.register_adapters(adapters) diff --git a/src/unstract/sdk/adapters/vectordb/helper.py b/src/unstract/sdk/adapters/vectordb/helper.py index 64f4553..94ae138 100644 --- a/src/unstract/sdk/adapters/vectordb/helper.py +++ b/src/unstract/sdk/adapters/vectordb/helper.py @@ -1,6 +1,5 @@ import logging import os -from typing import Optional from llama_index.core import ( MockEmbedding, @@ -10,7 +9,6 @@ ) from llama_index.core.llms import MockLLM from llama_index.core.vector_stores.types import BasePydanticVectorStore - from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.exceptions import VectorDBError @@ -20,7 +18,7 @@ class VectorDBHelper: @staticmethod def test_vector_db_instance( - vector_store: Optional[BasePydanticVectorStore], + vector_store: BasePydanticVectorStore | None, ) -> bool: try: if vector_store is None: @@ -64,8 +62,7 @@ def get_collection_name( collection_name_prefix: str, embedding_dimension: int, ) -> str: - """ - Notes: + """Notes: This function constructs the collection / table name to store the documents in the vector db. If user supplies this field in the config metadata then system diff --git a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py index 5cdb85c..12c1fbc 100644 --- a/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py +++ b/src/unstract/sdk/adapters/vectordb/milvus/src/milvus.py @@ -1,10 +1,9 @@ import os -from typing import Any, Optional +from typing import Any from llama_index.core.vector_stores.types import VectorStore from llama_index.vector_stores.milvus import MilvusVectorStore from pymilvus import MilvusClient - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.adapters.vectordb.helper import VectorDBHelper @@ -20,7 +19,7 @@ class Constants: class Milvus(VectorDBAdapter): def __init__(self, settings: dict[str, Any]): self._config = settings - self._client: Optional[MilvusClient] = None + self._client: MilvusClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._vector_db_instance = self._get_vector_db_instance() super().__init__("Milvus", self._vector_db_instance) @@ -43,8 +42,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/Milvus.png" - - def get_vector_db_instance(self) -> VectorStore: return self._vector_db_instance @@ -72,9 +69,7 @@ def _get_vector_db_instance(self) -> VectorStore: def test_connection(self) -> bool: vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance( - vector_store=vector_db - ) + test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing if self._client is not None: self._client.drop_collection(self._collection_name) diff --git a/src/unstract/sdk/adapters/vectordb/no_op/src/no_op_vectordb.py b/src/unstract/sdk/adapters/vectordb/no_op/src/no_op_vectordb.py index 032a3c8..93dfe7a 100644 --- a/src/unstract/sdk/adapters/vectordb/no_op/src/no_op_vectordb.py +++ b/src/unstract/sdk/adapters/vectordb/no_op/src/no_op_vectordb.py @@ -4,7 +4,6 @@ from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import VectorStore - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.adapters.vectordb.helper import VectorDBHelper @@ -39,8 +38,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/noOpVectorDb.png" - - def get_vector_db_instance(self) -> VectorStore: return self._vector_db_instance diff --git a/src/unstract/sdk/adapters/vectordb/pinecone/src/pinecone.py b/src/unstract/sdk/adapters/vectordb/pinecone/src/pinecone.py index 3577bd2..0443498 100644 --- a/src/unstract/sdk/adapters/vectordb/pinecone/src/pinecone.py +++ b/src/unstract/sdk/adapters/vectordb/pinecone/src/pinecone.py @@ -1,14 +1,12 @@ import logging import os -from typing import Any, Optional +from typing import Any from llama_index.core.schema import BaseNode from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.pinecone import PineconeVectorStore -from pinecone import NotFoundException +from pinecone import NotFoundException, PodSpec, ServerlessSpec from pinecone import Pinecone as LLamaIndexPinecone -from pinecone import PodSpec, ServerlessSpec - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.adapters.vectordb.helper import VectorDBHelper @@ -35,7 +33,7 @@ class Constants: class Pinecone(VectorDBAdapter): def __init__(self, settings: dict[str, Any]): self._config = settings - self._client: Optional[LLamaIndexPinecone] = None + self._client: LLamaIndexPinecone | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._vector_db_instance = self._get_vector_db_instance() super().__init__("Pinecone", self._vector_db_instance) @@ -58,8 +56,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/pinecone.png" - - def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance @@ -96,9 +92,7 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: try: self._client.describe_index(name=self._collection_name) except NotFoundException: - logger.info( - f"Index:{self._collection_name} does not exist. Creating it." - ) + logger.info(f"Index:{self._collection_name} does not exist. Creating it.") self._client.create_index( name=self._collection_name, dimension=dimension, @@ -116,9 +110,7 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: def test_connection(self) -> bool: vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance( - vector_store=vector_db - ) + test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing if self._client: self._client.delete_index(self._collection_name) diff --git a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py index 0811d6d..3667612 100644 --- a/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py +++ b/src/unstract/sdk/adapters/vectordb/postgres/src/postgres.py @@ -1,12 +1,11 @@ import os -from typing import Any, Optional +from typing import Any from urllib.parse import quote_plus import psycopg2 from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.postgres import PGVectorStore from psycopg2._psycopg import connection - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.adapters.vectordb.helper import VectorDBHelper @@ -26,7 +25,7 @@ class Constants: class Postgres(VectorDBAdapter): def __init__(self, settings: dict[str, Any]): self._config = settings - self._client: Optional[connection] = None + self._client: connection | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._schema_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._vector_db_instance = self._get_vector_db_instance() @@ -50,8 +49,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/postgres.png" - - def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance @@ -99,9 +96,7 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: def test_connection(self) -> bool: vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance( - vector_store=vector_db - ) + test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing if self._client is not None: diff --git a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py index aa1a508..767a387 100644 --- a/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py +++ b/src/unstract/sdk/adapters/vectordb/qdrant/src/qdrant.py @@ -1,12 +1,11 @@ import logging import os -from typing import Any, Optional +from typing import Any from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.qdrant import QdrantVectorStore from qdrant_client import QdrantClient from qdrant_client.http.exceptions import UnexpectedResponse - from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.adapters.vectordb.helper import VectorDBHelper from unstract.sdk.adapters.vectordb.vectordb_adapter import VectorDBAdapter @@ -23,7 +22,7 @@ class Constants: class Qdrant(VectorDBAdapter): def __init__(self, settings: dict[str, Any]): self._config = settings - self._client: Optional[QdrantClient] = None + self._client: QdrantClient | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._vector_db_instance = self._get_vector_db_instance() super().__init__("Qdrant", self._vector_db_instance) @@ -46,8 +45,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/qdrant.png" - - def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance @@ -58,7 +55,7 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: self._config.get(VectorDbConstants.EMBEDDING_DIMENSION), ) url = self._config.get(Constants.URL) - api_key: Optional[str] = self._config.get(Constants.API_KEY, None) + api_key: str | None = self._config.get(Constants.API_KEY, None) if api_key: self._client = QdrantClient(url=url, api_key=api_key) else: diff --git a/src/unstract/sdk/adapters/vectordb/supabase/src/supabase.py b/src/unstract/sdk/adapters/vectordb/supabase/src/supabase.py index e5158db..aca5749 100644 --- a/src/unstract/sdk/adapters/vectordb/supabase/src/supabase.py +++ b/src/unstract/sdk/adapters/vectordb/supabase/src/supabase.py @@ -1,15 +1,14 @@ import os -from typing import Any, Optional +from typing import Any from urllib.parse import quote_plus from llama_index.core.vector_stores.types import VectorStore from llama_index.vector_stores.supabase import SupabaseVectorStore -from vecs import Client - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.adapters.vectordb.helper import VectorDBHelper from unstract.sdk.adapters.vectordb.vectordb_adapter import VectorDBAdapter +from vecs import Client class Constants: @@ -24,7 +23,7 @@ class Constants: class Supabase(VectorDBAdapter): def __init__(self, settings: dict[str, Any]): self._config = settings - self._client: Optional[Client] = None + self._client: Client | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._vector_db_instance = self._get_vector_db_instance() super().__init__("Supabase", self._vector_db_instance) @@ -47,8 +46,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/supabase.png" - - def get_vector_db_instance(self) -> VectorStore: return self._vector_db_instance @@ -88,9 +85,7 @@ def _get_vector_db_instance(self) -> VectorStore: def test_connection(self) -> bool: vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance( - vector_store=vector_db - ) + test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing if self._client is not None: self._client.delete_collection(self._collection_name) diff --git a/src/unstract/sdk/adapters/vectordb/vectordb_adapter.py b/src/unstract/sdk/adapters/vectordb/vectordb_adapter.py index e3240bd..7db3923 100644 --- a/src/unstract/sdk/adapters/vectordb/vectordb_adapter.py +++ b/src/unstract/sdk/adapters/vectordb/vectordb_adapter.py @@ -1,10 +1,9 @@ from abc import ABC -from typing import Any, Union +from typing import Any from llama_index.core.schema import BaseNode from llama_index.core.vector_stores import SimpleVectorStore from llama_index.core.vector_stores.types import BasePydanticVectorStore, VectorStore - from unstract.sdk.adapters.base import Adapter from unstract.sdk.adapters.enums import AdapterTypes from unstract.sdk.exceptions import VectorDBError @@ -14,13 +13,13 @@ class VectorDBAdapter(Adapter, ABC): def __init__( self, name: str, - vector_db_instance: Union[VectorStore, BasePydanticVectorStore], + vector_db_instance: VectorStore | BasePydanticVectorStore, ): super().__init__(name) self.name = name - self._vector_db_instance: Union[ - VectorStore, BasePydanticVectorStore - ] = vector_db_instance + self._vector_db_instance: VectorStore | BasePydanticVectorStore = ( + vector_db_instance + ) @staticmethod def get_id() -> str: @@ -58,7 +57,7 @@ def parse_vector_db_err(e: Exception) -> VectorDBError: def get_vector_db_instance( self, vector_db_config: dict[str, Any] - ) -> Union[BasePydanticVectorStore, VectorStore]: + ) -> BasePydanticVectorStore | VectorStore: """Instantiate the llama index VectorStore / BasePydanticVectorStore class. diff --git a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py index fce926c..8e2c12a 100644 --- a/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py +++ b/src/unstract/sdk/adapters/vectordb/weaviate/src/weaviate.py @@ -1,17 +1,16 @@ import logging import os -from typing import Any, Optional +from typing import Any import weaviate from llama_index.core.vector_stores.types import BasePydanticVectorStore from llama_index.vector_stores.weaviate import WeaviateVectorStore -from weaviate.classes.init import Auth -from weaviate.exceptions import UnexpectedStatusCodeException - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.vectordb.constants import VectorDbConstants from unstract.sdk.adapters.vectordb.helper import VectorDBHelper from unstract.sdk.adapters.vectordb.vectordb_adapter import VectorDBAdapter +from weaviate.classes.init import Auth +from weaviate.exceptions import UnexpectedStatusCodeException logger = logging.getLogger(__name__) @@ -24,7 +23,7 @@ class Constants: class Weaviate(VectorDBAdapter): def __init__(self, settings: dict[str, Any]): self._config = settings - self._client: Optional[weaviate.Client] = None + self._client: weaviate.Client | None = None self._collection_name: str = VectorDbConstants.DEFAULT_VECTOR_DB_NAME self._vector_db_instance = self._get_vector_db_instance() super().__init__("Weaviate", self._vector_db_instance) @@ -47,8 +46,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/Weaviate.png" - - def get_vector_db_instance(self) -> BasePydanticVectorStore: return self._vector_db_instance @@ -92,9 +89,7 @@ def _get_vector_db_instance(self) -> BasePydanticVectorStore: def test_connection(self) -> bool: vector_db = self.get_vector_db_instance() - test_result: bool = VectorDBHelper.test_vector_db_instance( - vector_store=vector_db - ) + test_result: bool = VectorDBHelper.test_vector_db_instance(vector_store=vector_db) # Delete the collection that was created for testing if self._client is not None: self._client.collections.delete(self._collection_name) diff --git a/src/unstract/sdk/adapters/x2text/dto.py b/src/unstract/sdk/adapters/x2text/dto.py index efd2cd7..8fa226e 100644 --- a/src/unstract/sdk/adapters/x2text/dto.py +++ b/src/unstract/sdk/adapters/x2text/dto.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional @dataclass @@ -10,4 +9,4 @@ class TextExtractionMetadata: @dataclass class TextExtractionResult: extracted_text: str - extraction_metadata: Optional[TextExtractionMetadata] = None + extraction_metadata: TextExtractionMetadata | None = None diff --git a/src/unstract/sdk/adapters/x2text/helper.py b/src/unstract/sdk/adapters/x2text/helper.py index 28667f0..3a94872 100644 --- a/src/unstract/sdk/adapters/x2text/helper.py +++ b/src/unstract/sdk/adapters/x2text/helper.py @@ -1,10 +1,9 @@ import logging -from typing import Any, Optional +from typing import Any import requests from requests import Response from requests.exceptions import ConnectionError, HTTPError, Timeout - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.utils import AdapterUtils from unstract.sdk.adapters.x2text.constants import X2TextConstants @@ -20,7 +19,7 @@ class X2TextHelper: @staticmethod def parse_response( response: Response, - out_file_path: Optional[str] = None, + out_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> tuple[str, bool]: """Parses the response from a request. @@ -64,7 +63,7 @@ def test_server_connection(unstructured_adapter_config: dict[str, Any]) -> bool: def process_document( unstructured_adapter_config: dict[str, Any], input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> str: try: @@ -102,12 +101,8 @@ def make_request( ) -> Response: unstructured_url = unstructured_adapter_config.get(UnstructuredHelper.URL) - x2text_service_url = unstructured_adapter_config.get( - X2TextConstants.X2TEXT_HOST - ) - x2text_service_port = unstructured_adapter_config.get( - X2TextConstants.X2TEXT_PORT - ) + x2text_service_url = unstructured_adapter_config.get(X2TextConstants.X2TEXT_HOST) + x2text_service_port = unstructured_adapter_config.get(X2TextConstants.X2TEXT_PORT) platform_service_api_key = unstructured_adapter_config.get( X2TextConstants.PLATFORM_SERVICE_API_KEY ) @@ -124,23 +119,19 @@ def make_request( body["unstructured-api-key"] = api_key x2text_url = ( - f"{x2text_service_url}:{x2text_service_port}" - f"/api/v1/x2text/{request_type}" + f"{x2text_service_url}:{x2text_service_port}" f"/api/v1/x2text/{request_type}" ) # Add files only if the request is for process files = None if "files" in kwargs: files = kwargs["files"] if kwargs["files"] is not None else None try: - response = requests.post( - x2text_url, headers=headers, data=body, files=files - ) + response = requests.post(x2text_url, headers=headers, data=body, files=files) response.raise_for_status() except ConnectionError as e: logger.error(f"Adapter error: {e}") raise AdapterError( - "Unable to connect to unstructured-io's service, " - "please check the URL" + "Unable to connect to unstructured-io's service, " "please check the URL" ) except Timeout as e: msg = "Request to unstructured-io's service has timed out" diff --git a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py index 280f1d5..237a507 100644 --- a/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py +++ b/src/unstract/sdk/adapters/x2text/llama_parse/src/llama_parse.py @@ -1,11 +1,10 @@ import logging import os import pathlib -from typing import Any, Optional +from typing import Any from httpx import ConnectError from llama_parse import LlamaParse - from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.x2text.dto import TextExtractionResult from unstract.sdk.adapters.x2text.llama_parse.src.constants import LlamaParseConfig @@ -38,8 +37,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/llama-parse.png" - - def _call_parser( self, input_file_path: str, @@ -84,8 +81,7 @@ def _call_parser( except ConnectError as connec_err: logger.error(f"Invalid Base URL given. : {connec_err}") raise AdapterError( - "Unable to connect to llama-parse`s service, " - "please check the Base URL" + "Unable to connect to llama-parse`s service, " "please check the Base URL" ) except Exception as exe: logger.error( @@ -99,7 +95,7 @@ def _call_parser( def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py index cfa5517..71ef450 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer/src/llm_whisperer.py @@ -3,12 +3,11 @@ import os import time from pathlib import Path -from typing import Any, Optional +from typing import Any import requests from requests import Response from requests.exceptions import ConnectionError, HTTPError, Timeout - from unstract.sdk.adapters.exceptions import ExtractorError from unstract.sdk.adapters.utils import AdapterUtils from unstract.sdk.adapters.x2text.constants import X2TextConstants @@ -56,8 +55,6 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/LLMWhisperer.png" - - def _get_request_headers(self) -> dict[str, Any]: """Obtains the request headers to authenticate with LLMWhisperer. @@ -73,9 +70,9 @@ def _make_request( self, request_method: HTTPMethod, request_endpoint: str, - headers: Optional[dict[str, Any]] = None, - params: Optional[dict[str, Any]] = None, - data: Optional[Any] = None, + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + data: Any | None = None, ) -> Response: """Makes a request to LLMWhisperer service. @@ -329,7 +326,7 @@ def _send_whisper_request( def _extract_text_from_response( self, - output_file_path: Optional[str], + output_file_path: str | None, response: requests.Response, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> str: @@ -396,9 +393,7 @@ def _write_output_to_file( data=metadata_json, ) except Exception as e: - logger.error( - f"Error while writing metadata to {metadata_file_path}: {e}" - ) + logger.error(f"Error while writing metadata to {metadata_file_path}: {e}") except Exception as e: logger.error(f"Error while writing {output_file_path}: {e}") @@ -407,7 +402,7 @@ def _write_output_to_file( def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: @@ -422,7 +417,6 @@ def process( Returns: str: Extracted text """ - response: requests.Response = self._send_whisper_request( input_file_path, fs, diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/dto.py b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/dto.py index 5fd7974..11bc036 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/dto.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/dto.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from typing import Optional, Union, List @dataclass @@ -13,7 +12,7 @@ class WhispererRequestParams: """ # TODO: Extend this DTO to include all Whisperer API parameters - tag: Optional[Union[str, List[str]]] = None + tag: str | list[str] | None = None enable_highlight: bool = False def __post_init__(self) -> None: diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/helper.py b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/helper.py index 9135a93..ce90025 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/helper.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/helper.py @@ -2,7 +2,7 @@ import logging from io import BytesIO from pathlib import Path -from typing import Any, Optional +from typing import Any import requests from requests import Response @@ -11,9 +11,9 @@ LLMWhispererClientException, LLMWhispererClientV2, ) - from unstract.sdk.adapters.exceptions import ExtractorError from unstract.sdk.adapters.utils import AdapterUtils +from unstract.sdk.adapters.x2text.constants import X2TextConstants from unstract.sdk.adapters.x2text.llm_whisperer_v2.src.constants import ( Modes, OutputModes, @@ -22,7 +22,6 @@ WhispererHeader, WhisperStatus, ) -from unstract.sdk.adapters.x2text.constants import X2TextConstants from unstract.sdk.adapters.x2text.llm_whisperer_v2.src.dto import WhispererRequestParams from unstract.sdk.constants import MimeType from unstract.sdk.file_storage import FileStorage, FileStorageProvider @@ -77,10 +76,10 @@ def test_connection_request( @staticmethod def make_request( config: dict[str, Any], - headers: Optional[dict[str, Any]] = None, - params: Optional[dict[str, Any]] = None, - data: Optional[Any] = None, - type: str = "whisper" + headers: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + data: Any | None = None, + type: str = "whisper", ) -> Response: """Makes a request to LLMWhisperer service. @@ -111,8 +110,8 @@ def make_request( if type == "whisper": response = client.whisper(**params, stream=data) if response["status_code"] == 200: - response["extraction"][X2TextConstants.WHISPER_HASH_V2] = response.get( - X2TextConstants.WHISPER_HASH_V2, "" + response["extraction"][X2TextConstants.WHISPER_HASH_V2] = ( + response.get(X2TextConstants.WHISPER_HASH_V2, "") ) return response["extraction"] else: @@ -257,9 +256,7 @@ def send_whisper_request( @staticmethod def make_highlight_data_request( - config: dict[str, Any], - whisper_hash: str, - enable_highlight: bool + config: dict[str, Any], whisper_hash: str, enable_highlight: bool ) -> dict[Any, Any]: """Makes a call to get highlight data from LLMWhisperer. @@ -285,10 +282,10 @@ def make_highlight_data_request( type="highlight", ) return retrieve_response - + @staticmethod def extract_text_from_response( - output_file_path: Optional[str], + output_file_path: str | None, response: dict[str, Any], fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> str: @@ -355,4 +352,4 @@ def write_output_to_file( encoding="utf-8", ) except Exception as e: - logger.warn(f"Error while writing metadata to {metadata_file_path}: {e}") \ No newline at end of file + logger.warn(f"Error while writing metadata to {metadata_file_path}: {e}") diff --git a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py index 2b5b06b..166b9f9 100644 --- a/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py +++ b/src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py @@ -1,9 +1,8 @@ import logging import os -from typing import Any, Optional +from typing import Any import requests - from unstract.sdk.adapters.x2text.constants import X2TextConstants from unstract.sdk.adapters.x2text.dto import ( TextExtractionMetadata, @@ -53,7 +52,7 @@ def test_connection(self) -> bool: def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: @@ -68,7 +67,6 @@ def process( Returns: str: Extracted text """ - enable_highlight = kwargs.get(X2TextConstants.ENABLE_HIGHLIGHT, False) extra_params = WhispererRequestParams( tag=kwargs.get(X2TextConstants.TAGS), @@ -86,7 +84,9 @@ def process( return TextExtractionResult( extracted_text=LLMWhispererHelper.extract_text_from_response( - output_file_path, response, fs=fs, + output_file_path, + response, + fs=fs, ), extraction_metadata=metadata, ) diff --git a/src/unstract/sdk/adapters/x2text/no_op/src/no_op_x2text.py b/src/unstract/sdk/adapters/x2text/no_op/src/no_op_x2text.py index b79df2b..6366f44 100644 --- a/src/unstract/sdk/adapters/x2text/no_op/src/no_op_x2text.py +++ b/src/unstract/sdk/adapters/x2text/no_op/src/no_op_x2text.py @@ -1,7 +1,7 @@ import logging import os import time -from typing import Any, Optional +from typing import Any from unstract.sdk.adapters.x2text.dto import TextExtractionResult from unstract.sdk.adapters.x2text.x2text_adapter import X2TextAdapter @@ -33,12 +33,10 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/noOpx2Text.png" - - def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: diff --git a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py index 46e8ff9..b205ce0 100644 --- a/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py +++ b/src/unstract/sdk/adapters/x2text/unstructured_community/src/unstructured_community.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Optional +from typing import Any from unstract.sdk.adapters.x2text.dto import TextExtractionResult from unstract.sdk.adapters.x2text.helper import UnstructuredHelper @@ -33,12 +33,10 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/UnstructuredIO.png" - - def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: diff --git a/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py b/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py index f3bb479..908be6d 100644 --- a/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py +++ b/src/unstract/sdk/adapters/x2text/unstructured_enterprise/src/unstructured_enterprise.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Optional +from typing import Any from unstract.sdk.adapters.x2text.dto import TextExtractionResult from unstract.sdk.adapters.x2text.helper import UnstructuredHelper @@ -33,12 +33,10 @@ def get_description() -> str: def get_icon() -> str: return "/icons/adapter-icons/UnstructuredIO.png" - - def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[str, Any], ) -> TextExtractionResult: diff --git a/src/unstract/sdk/adapters/x2text/x2text_adapter.py b/src/unstract/sdk/adapters/x2text/x2text_adapter.py index 8fe9a1d..3516fc7 100644 --- a/src/unstract/sdk/adapters/x2text/x2text_adapter.py +++ b/src/unstract/sdk/adapters/x2text/x2text_adapter.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Optional +from typing import Any from unstract.sdk.adapters.base import Adapter from unstract.sdk.adapters.enums import AdapterTypes @@ -38,7 +38,7 @@ def test_connection(self) -> bool: def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: diff --git a/src/unstract/sdk/audit.py b/src/unstract/sdk/audit.py index 9544594..bff3df3 100644 --- a/src/unstract/sdk/audit.py +++ b/src/unstract/sdk/audit.py @@ -1,8 +1,7 @@ -from typing import Any, Union +from typing import Any import requests from llama_index.core.callbacks import CBEventType, TokenCountingHandler - from unstract.sdk.constants import LogLevel, ToolEnv from unstract.sdk.helper import SdkHelper from unstract.sdk.tool.stream import StreamMixin @@ -26,7 +25,7 @@ def __init__(self, log_level: LogLevel = LogLevel.INFO) -> None: def push_usage_data( self, platform_api_key: str, - token_counter: Union[TokenCountingHandler, TokenCounter] = None, + token_counter: TokenCountingHandler | TokenCounter = None, model_name: str = "", event_type: CBEventType = None, kwargs: dict[Any, Any] = None, diff --git a/src/unstract/sdk/cache.py b/src/unstract/sdk/cache.py index 5514860..cd515c3 100644 --- a/src/unstract/sdk/cache.py +++ b/src/unstract/sdk/cache.py @@ -1,7 +1,6 @@ -from typing import Any, Optional +from typing import Any import requests - from unstract.sdk.constants import LogLevel from unstract.sdk.platform import PlatformBase from unstract.sdk.tool.base import BaseTool @@ -14,11 +13,8 @@ class ToolCache(PlatformBase): - PLATFORM_SERVICE_API_KEY environment variable is required. """ - def __init__( - self, tool: BaseTool, platform_host: str, platform_port: int - ) -> None: - """ - Args: + def __init__(self, tool: BaseTool, platform_host: str, platform_port: int) -> None: + """Args: tool (AbstractTool): Instance of AbstractTool platform_host (str): The host of the platform. platform_port (int): The port of the platform. @@ -42,7 +38,6 @@ def set(self, key: str, value: str) -> bool: Returns: bool: Whether the operation was successful. """ - url = f"{self.base_url}/cache" json = {"key": key, "value": value} headers = {"Authorization": f"Bearer {self.bearer_token}"} @@ -58,7 +53,7 @@ def set(self, key: str, value: str) -> bool: ) return False - def get(self, key: str) -> Optional[Any]: + def get(self, key: str) -> Any | None: """Gets the value for a key in the cache. Args: @@ -67,20 +62,15 @@ def get(self, key: str) -> Optional[Any]: Returns: str: The value. """ - url = f"{self.base_url}/cache?key={key}" headers = {"Authorization": f"Bearer {self.bearer_token}"} response = requests.get(url, headers=headers) if response.status_code == 200: - self.tool.stream_log( - f"Successfully retrieved cached data for key: {key}" - ) + self.tool.stream_log(f"Successfully retrieved cached data for key: {key}") return response.text elif response.status_code == 404: - self.tool.stream_log( - f"Data not found for key: {key}", level=LogLevel.WARN - ) + self.tool.stream_log(f"Data not found for key: {key}", level=LogLevel.WARN) return None else: self.tool.stream_log( @@ -104,14 +94,11 @@ def delete(self, key: str) -> bool: response = requests.delete(url, headers=headers) if response.status_code == 200: - self.tool.stream_log( - f"Successfully deleted cached data for key: {key}" - ) + self.tool.stream_log(f"Successfully deleted cached data for key: {key}") return True else: self.tool.stream_log( - "Error while deleting cached data " - f"for key: {key} / {response.reason}", + "Error while deleting cached data " f"for key: {key} / {response.reason}", level=LogLevel.ERROR, ) return False diff --git a/src/unstract/sdk/embedding.py b/src/unstract/sdk/embedding.py index 2e8c3d2..71f2f5f 100644 --- a/src/unstract/sdk/embedding.py +++ b/src/unstract/sdk/embedding.py @@ -1,10 +1,9 @@ -from typing import Any, Optional +from typing import Any from deprecated import deprecated from llama_index.core.base.embeddings.base import Embedding from llama_index.core.callbacks import CallbackManager as LlamaIndexCallbackManager from llama_index.core.embeddings import BaseEmbedding - from unstract.sdk.adapter import ToolAdapter from unstract.sdk.adapters.constants import Common from unstract.sdk.adapters.embedding import adapters @@ -23,7 +22,7 @@ class Embedding: def __init__( self, tool: BaseTool, - adapter_instance_id: Optional[str] = None, + adapter_instance_id: str | None = None, usage_kwargs: dict[Any, Any] = {}, ): self._tool = tool @@ -88,9 +87,7 @@ def get_query_embedding(self, query: str) -> Embedding: return self._embedding_instance.get_query_embedding(query) def _get_embedding_length(self) -> int: - embedding_list = self._embedding_instance._get_text_embedding( - self._TEST_SNIPPET - ) + embedding_list = self._embedding_instance._get_text_embedding(self._TEST_SNIPPET) embedding_dimension = len(embedding_list) return embedding_dimension @@ -100,7 +97,7 @@ def get_class_name(self) -> str: Args: NA - Returns: + Returns: Class name """ return self._embedding_instance.class_name() @@ -111,7 +108,7 @@ def get_callback_manager(self) -> LlamaIndexCallbackManager: Args: NA - Returns: + Returns: llama-index callback manager """ return self._embedding_instance.callback_manager diff --git a/src/unstract/sdk/exceptions.py b/src/unstract/sdk/exceptions.py index b5dd375..db7c560 100644 --- a/src/unstract/sdk/exceptions.py +++ b/src/unstract/sdk/exceptions.py @@ -1,6 +1,3 @@ -from typing import Optional - - def resolve_err_status_code(client_status_code: int) -> int: """Resolves the status code to return in case of errors. @@ -25,14 +22,14 @@ def resolve_err_status_code(client_status_code: int) -> int: class SdkError(Exception): DEFAULT_MESSAGE = "Something went wrong" - actual_err: Optional[Exception] = None - status_code: Optional[int] = None + actual_err: Exception | None = None + status_code: int | None = None def __init__( self, message: str = DEFAULT_MESSAGE, - status_code: Optional[int] = None, - actual_err: Optional[Exception] = None, + status_code: int | None = None, + actual_err: Exception | None = None, ): super().__init__(message) # Make it user friendly wherever possible diff --git a/src/unstract/sdk/file_storage/helper.py b/src/unstract/sdk/file_storage/helper.py index 2f8abb8..afb4aa3 100644 --- a/src/unstract/sdk/file_storage/helper.py +++ b/src/unstract/sdk/file_storage/helper.py @@ -3,7 +3,6 @@ import fsspec from fsspec import AbstractFileSystem - from unstract.sdk.exceptions import FileOperationError, FileStorageError from unstract.sdk.file_storage.provider import FileStorageProvider @@ -25,7 +24,6 @@ def file_storage_init( Returns: NA """ - try: protocol = provider.value if provider == FileStorageProvider.LOCAL: diff --git a/src/unstract/sdk/file_storage/impl.py b/src/unstract/sdk/file_storage/impl.py index e837876..64cdfd0 100644 --- a/src/unstract/sdk/file_storage/impl.py +++ b/src/unstract/sdk/file_storage/impl.py @@ -2,13 +2,12 @@ import logging from datetime import datetime from hashlib import sha256 -from typing import Any, Union +from typing import Any import filetype import fsspec import magic import yaml - from unstract.sdk.exceptions import FileOperationError from unstract.sdk.file_storage.constants import FileOperationParams, FileSeekPosition from unstract.sdk.file_storage.helper import FileStorageHelper, skip_local_cache @@ -36,7 +35,7 @@ def read( encoding: str = FileOperationParams.DEFAULT_ENCODING, seek_position: int = 0, length: int = FileOperationParams.READ_ENTIRE_LENGTH, - ) -> Union[bytes, str]: + ) -> bytes | str: """Read the file pointed to by the file_handle. Args: @@ -62,7 +61,7 @@ def write( mode: str, encoding: str = FileOperationParams.DEFAULT_ENCODING, seek_position: int = 0, - data: Union[bytes, str] = "", + data: bytes | str = "", ) -> int: """Write data in the file pointed to by the file-handle. @@ -306,7 +305,6 @@ def get_hash_from_file(self, path: str) -> str: Returns: str: SHA256 hash of the file """ - h = sha256() b = bytearray(128 * 1024) mv = memoryview(b) diff --git a/src/unstract/sdk/file_storage/interface.py b/src/unstract/sdk/file_storage/interface.py index 3b3e4e9..8fd4624 100644 --- a/src/unstract/sdk/file_storage/interface.py +++ b/src/unstract/sdk/file_storage/interface.py @@ -1,9 +1,8 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Any, Union +from typing import Any from fsspec import AbstractFileSystem - from unstract.sdk.file_storage.constants import FileOperationParams, FileSeekPosition @@ -16,7 +15,7 @@ def read( encoding: str = FileOperationParams.DEFAULT_ENCODING, seek_position: int = 0, length: int = FileOperationParams.READ_ENTIRE_LENGTH, - ) -> Union[bytes, str]: + ) -> bytes | str: pass @abstractmethod @@ -26,14 +25,14 @@ def write( mode: str, encoding: str = FileOperationParams.DEFAULT_ENCODING, seek_position: int = 0, - data: Union[bytes, str] = "", + data: bytes | str = "", ) -> int: pass @abstractmethod def seek( self, - file_handle: Union[AbstractFileSystem], + file_handle: AbstractFileSystem, location: int = 0, position: FileSeekPosition = FileSeekPosition.START, ) -> int: diff --git a/src/unstract/sdk/file_storage/permanent.py b/src/unstract/sdk/file_storage/permanent.py index 1745deb..d1ccf42 100644 --- a/src/unstract/sdk/file_storage/permanent.py +++ b/src/unstract/sdk/file_storage/permanent.py @@ -1,9 +1,8 @@ import logging -from typing import Any, Optional, Union +from typing import Any import filetype import magic - from unstract.sdk.exceptions import FileOperationError, FileStorageError from unstract.sdk.file_storage.constants import FileOperationParams from unstract.sdk.file_storage.impl import FileStorage @@ -77,8 +76,8 @@ def read( encoding: str = FileOperationParams.DEFAULT_ENCODING, seek_position: int = 0, length: int = FileOperationParams.READ_ENTIRE_LENGTH, - legacy_storage_path: Optional[str] = None, - ) -> Union[bytes, str]: + legacy_storage_path: str | None = None, + ) -> bytes | str: """Read the file pointed to by the file_handle. Args: @@ -108,7 +107,7 @@ def mime_type( self, path: str, read_length: int = FileOperationParams.READ_ENTIRE_LENGTH, - legacy_storage_path: Optional[str] = None, + legacy_storage_path: str | None = None, ) -> str: """Gets the file MIME type for an input file. Uses libmagic to perform the same. @@ -131,9 +130,7 @@ def mime_type( mime_type = magic.from_buffer(sample_contents, mime=True) return mime_type - def guess_extension( - self, path: str, legacy_storage_path: Optional[str] = None - ) -> str: + def guess_extension(self, path: str, legacy_storage_path: str | None = None) -> str: """Returns the extension of the file passed. Args: diff --git a/src/unstract/sdk/index.py b/src/unstract/sdk/index.py index cbfd5c3..7392911 100644 --- a/src/unstract/sdk/index.py +++ b/src/unstract/sdk/index.py @@ -1,6 +1,7 @@ import json import logging -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from deprecated import deprecated from llama_index.core import Document @@ -12,7 +13,6 @@ VectorStoreQuery, VectorStoreQueryResult, ) - from unstract.sdk.adapter import ToolAdapter from unstract.sdk.adapters.exceptions import AdapterError from unstract.sdk.adapters.vectordb.no_op.src.no_op_custom_vectordb import ( @@ -43,7 +43,7 @@ class Index: def __init__( self, tool: BaseTool, - run_id: Optional[str] = None, + run_id: str | None = None, capture_metrics: bool = False, ): # TODO: Inherit from StreamMixin and avoid using BaseTool @@ -124,12 +124,12 @@ def extract_text( self, x2text_instance_id: str, file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, enable_highlight: bool = False, usage_kwargs: dict[Any, Any] = {}, - process_text: Optional[Callable[[str], str]] = None, + process_text: Callable[[str], str] | None = None, fs: FileStorage = FileStorage(FileStorageProvider.LOCAL), - tags: Optional[list[str]] = None, + tags: list[str] | None = None, ) -> str: """Extracts text from a document. @@ -218,13 +218,13 @@ def index( chunk_size: int, chunk_overlap: int, reindex: bool = False, - file_hash: Optional[str] = None, - output_file_path: Optional[str] = None, + file_hash: str | None = None, + output_file_path: str | None = None, enable_highlight: bool = False, usage_kwargs: dict[Any, Any] = {}, - process_text: Optional[Callable[[str], str]] = None, + process_text: Callable[[str], str] | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), - tags: Optional[list[str]] = None, + tags: list[str] | None = None, ) -> str: """Indexes an individual file using the passed arguments. @@ -448,8 +448,8 @@ def generate_index_key( x2text: str, chunk_size: str, chunk_overlap: str, - file_path: Optional[str] = None, - file_hash: Optional[str] = None, + file_path: str | None = None, + file_hash: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> str: """Generates a unique ID useful for identifying files during indexing. @@ -508,8 +508,8 @@ def generate_file_id( x2text: str, chunk_size: str, chunk_overlap: str, - file_path: Optional[str] = None, - file_hash: Optional[str] = None, + file_path: str | None = None, + file_hash: str | None = None, ) -> str: return self.generate_index_key( vector_db, @@ -533,8 +533,8 @@ def index_file( chunk_size: int, chunk_overlap: int, reindex: bool = False, - file_hash: Optional[str] = None, - output_file_path: Optional[str] = None, + file_hash: str | None = None, + output_file_path: str | None = None, ) -> str: return self.index( tool_id=tool_id, @@ -552,7 +552,7 @@ def index_file( @deprecated("Deprecated class and method. Use Index and query_index() instead") def get_text_from_index( self, embedding_type: str, vector_db: str, doc_id: str - ) -> Optional[str]: + ) -> str | None: return self.query_index( embedding_instance_id=embedding_type, vector_db_instance_id=vector_db, diff --git a/src/unstract/sdk/llm.py b/src/unstract/sdk/llm.py index f7c18c5..4b9bbcb 100644 --- a/src/unstract/sdk/llm.py +++ b/src/unstract/sdk/llm.py @@ -1,6 +1,7 @@ import logging import re -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any from deprecated import deprecated from llama_index.core.base.llms.types import CompletionResponseGen @@ -8,7 +9,6 @@ from llama_index.core.llms import CompletionResponse from openai import APIError as OpenAIAPIError from openai import RateLimitError as OpenAIRateLimitError - from unstract.sdk.adapter import ToolAdapter from unstract.sdk.adapters.constants import Common from unstract.sdk.adapters.llm import adapters @@ -35,7 +35,7 @@ class LLM: def __init__( self, tool: BaseTool, - adapter_instance_id: Optional[str] = None, + adapter_instance_id: str | None = None, usage_kwargs: dict[Any, Any] = {}, capture_metrics: bool = False, ): @@ -76,7 +76,7 @@ def complete( self, prompt: str, extract_json: bool = True, - process_text: Optional[Callable[[str], str]] = None, + process_text: Callable[[str], str] | None = None, **kwargs: Any, ) -> dict[str, Any]: """Generates a completion response for the given prompt and captures @@ -176,7 +176,7 @@ def get_max_tokens(self, reserved_for_output: int = 0) -> int: output. The default is 0. - Returns: + Returns: int: The maximum number of tokens that can be used for the LLM. """ return self.MAX_TOKENS - reserved_for_output @@ -187,7 +187,7 @@ def set_max_tokens(self, max_tokens: int) -> None: Args: max_tokens (int): The number of tokens to be used at the maximum - Returns: + Returns: None """ self._llm_instance.max_tokens = max_tokens @@ -198,7 +198,7 @@ def get_class_name(self) -> str: Args: NA - Returns: + Returns: Class name """ return self._llm_instance.class_name() @@ -215,7 +215,7 @@ def get_model_name(self) -> str: return self._llm_instance.model @deprecated("Use LLM instead of ToolLLM") - def get_llm(self, adapter_instance_id: Optional[str] = None) -> LlamaIndexLLM: + def get_llm(self, adapter_instance_id: str | None = None) -> LlamaIndexLLM: if not self._llm_instance: self._adapter_instance_id = adapter_instance_id self._initialise() @@ -230,7 +230,7 @@ def run_completion( prompt: str, retries: int = 3, **kwargs: Any, - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: # Setup callback manager to collect Usage stats CallbackManager.set_callback_manager( platform_api_key=platform_api_key, llm=llm, **kwargs diff --git a/src/unstract/sdk/metrics_mixin.py b/src/unstract/sdk/metrics_mixin.py index 3adedd4..2b0d744 100644 --- a/src/unstract/sdk/metrics_mixin.py +++ b/src/unstract/sdk/metrics_mixin.py @@ -32,9 +32,7 @@ def __init__(self, run_id): decode_responses=True, ) except Exception as e: - logger.error( - "Failed to initialize Redis client" f" for run_id={run_id}: {e}" - ) + logger.error("Failed to initialize Redis client" f" for run_id={run_id}: {e}") self.redis_key = f"metrics:{self.run_id}:{self.op_id}" @@ -56,7 +54,6 @@ def collect_metrics(self) -> dict[str, Any]: Returns: dict: The calculated time taken and the associated run_id and op_id. """ - if self.redis_client is None: logger.error("Redis client is not initialized. Cannot collect metrics.") return {self.TIME_TAKEN_KEY: None} diff --git a/src/unstract/sdk/ocr.py b/src/unstract/sdk/ocr.py index 3ed3df9..249a977 100644 --- a/src/unstract/sdk/ocr.py +++ b/src/unstract/sdk/ocr.py @@ -1,8 +1,6 @@ from abc import ABCMeta -from typing import Optional from deprecated import deprecated - from unstract.sdk.adapter import ToolAdapter from unstract.sdk.adapters.constants import Common from unstract.sdk.adapters.ocr import adapters @@ -16,7 +14,7 @@ class OCR(metaclass=ABCMeta): def __init__( self, tool: BaseTool, - adapter_instance_id: Optional[str] = None, + adapter_instance_id: str | None = None, ): self._tool = tool self._ocr_adapters = adapters @@ -28,7 +26,7 @@ def _initialise(self, adapter_instance_id): if self._adapter_instance_id: self._ocr_instance: OCRAdapter = self._get_ocr() - def _get_ocr(self) -> Optional[OCRAdapter]: + def _get_ocr(self) -> OCRAdapter | None: try: if not self._adapter_instance_id: raise OCRError("Adapter instance ID not set. " "Initialisation failed") @@ -52,9 +50,7 @@ def _get_ocr(self) -> Optional[OCRAdapter]: ) return None - def process( - self, input_file_path: str, output_file_path: Optional[str] = None - ) -> str: + def process(self, input_file_path: str, output_file_path: str | None = None) -> str: return self._ocr_instance.process(input_file_path, output_file_path) @deprecated("Instantiate OCR and call process() instead") diff --git a/src/unstract/sdk/platform.py b/src/unstract/sdk/platform.py index 0fd183d..dd9c8d7 100644 --- a/src/unstract/sdk/platform.py +++ b/src/unstract/sdk/platform.py @@ -1,7 +1,6 @@ -from typing import Any, Optional +from typing import Any import requests - from unstract.sdk.constants import LogLevel, ToolEnv from unstract.sdk.helper import SdkHelper from unstract.sdk.tool.base import BaseTool @@ -20,8 +19,7 @@ def __init__( platform_host: str, platform_port: str, ) -> None: - """ - Args: + """Args: tool (AbstractTool): Instance of AbstractTool platform_host (str): Host of platform service platform_port (str): Port of platform service @@ -54,7 +52,7 @@ def __init__(self, tool: BaseTool, platform_host: str, platform_port: str): tool=tool, platform_host=platform_host, platform_port=platform_port ) - def get_platform_details(self) -> Optional[dict[str, Any]]: + def get_platform_details(self) -> dict[str, Any] | None: """Obtains platform details associated with the platform key. Currently helps fetch organization ID related to the key. diff --git a/src/unstract/sdk/prompt.py b/src/unstract/sdk/prompt.py index 55cffdc..ba8d6f7 100644 --- a/src/unstract/sdk/prompt.py +++ b/src/unstract/sdk/prompt.py @@ -1,9 +1,8 @@ import logging -from typing import Any, Optional +from typing import Any import requests from requests import ConnectionError, RequestException, Response - from unstract.sdk.constants import LogLevel, MimeType, PromptStudioKeys, ToolEnv from unstract.sdk.helper import SdkHelper from unstract.sdk.tool.base import BaseTool @@ -22,11 +21,10 @@ def __init__( prompt_port: str, is_public_call: bool = False, ) -> None: - """ - Args: - tool (AbstractTool): Instance of AbstractTool - prompt_host (str): Host of platform service - prompt_host (str): Port of platform service + """Args: + tool (AbstractTool): Instance of AbstractTool + prompt_host (str): Host of platform service + prompt_host (str): Port of platform service """ self.tool = tool self.base_url = SdkHelper.get_platform_base_url(prompt_host, prompt_port) @@ -38,8 +36,8 @@ def __init__( def answer_prompt( self, payload: dict[str, Any], - params: Optional[dict[str, str]] = None, - headers: Optional[dict[str, str]] = None, + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: url_path = "answer-prompt" if self.is_public_call: @@ -47,13 +45,13 @@ def answer_prompt( return self._post_call( url_path=url_path, payload=payload, params=params, headers=headers ) - + @log_elapsed(operation="INDEX") def index( - self, - payload: dict[str, Any], - params: Optional[dict[str, str]] = None, - headers: Optional[dict[str, str]] = None, + self, + payload: dict[str, Any], + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: url_path = "index" if self.is_public_call: @@ -64,13 +62,13 @@ def index( params=params, headers=headers, ) - + @log_elapsed(operation="EXTRACT") def extract( - self, - payload: dict[str, Any], - params: Optional[dict[str, str]] = None, - headers: Optional[dict[str, str]] = None, + self, + payload: dict[str, Any], + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: url_path = "extract" if self.is_public_call: @@ -85,8 +83,8 @@ def extract( def single_pass_extraction( self, payload: dict[str, Any], - params: Optional[dict[str, str]] = None, - headers: Optional[dict[str, str]] = None, + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: return self._post_call( url_path="single-pass-extraction", @@ -98,8 +96,8 @@ def single_pass_extraction( def summarize( self, payload: dict[str, Any], - params: Optional[dict[str, str]] = None, - headers: Optional[dict[str, str]] = None, + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: return self._post_call( url_path="summarize", @@ -112,8 +110,8 @@ def _post_call( self, url_path: str, payload: dict[str, Any], - params: Optional[dict[str, str]] = None, - headers: Optional[dict[str, str]] = None, + params: dict[str, str] | None = None, + headers: dict[str, str] | None = None, ) -> dict[str, Any]: """Invokes and communicates to prompt service to fetch response for the prompt. @@ -192,7 +190,7 @@ def _stringify_and_stream_err(self, err: RequestException, msg: str) -> None: @staticmethod def get_exported_tool( tool: BaseTool, prompt_registry_id: str - ) -> Optional[dict[str, Any]]: + ) -> dict[str, Any] | None: """Get exported custom tool by the help of unstract DB tool. Args: diff --git a/src/unstract/sdk/tool/base.py b/src/unstract/sdk/tool/base.py index 774f14d..c247f32 100644 --- a/src/unstract/sdk/tool/base.py +++ b/src/unstract/sdk/tool/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from json import JSONDecodeError from pathlib import Path -from typing import Any, Union +from typing import Any from unstract.sdk.constants import ( Command, @@ -112,6 +112,7 @@ def handle_static_command(self, command: str) -> None: Args: command (str): The static command. + Returns: None """ @@ -130,9 +131,7 @@ def _get_file_from_data_dir(self, file_to_get: str, raise_err: bool = False) -> base_path = self.execution_dir file_path = base_path / file_to_get if raise_err and not self.workflow_filestorage.exists(path=file_path): - self.stream_error_and_exit( - f"{file_to_get} is missing in EXECUTION_DATA_DIR" - ) + self.stream_error_and_exit(f"{file_to_get} is missing in EXECUTION_DATA_DIR") return str(file_path) @@ -254,7 +253,7 @@ def update_exec_metadata(self, metadata: dict[str, Any]) -> None: self._write_exec_metadata(metadata=self._exec_metadata) - def write_tool_result(self, data: Union[str, dict[str, Any]]) -> None: + def write_tool_result(self, data: str | dict[str, Any]) -> None: """Helps write contents of the tool result into TOOL_DATA_DIR. Args: diff --git a/src/unstract/sdk/tool/mixin.py b/src/unstract/sdk/tool/mixin.py index 6c225c1..394a608 100644 --- a/src/unstract/sdk/tool/mixin.py +++ b/src/unstract/sdk/tool/mixin.py @@ -20,6 +20,7 @@ def spec( Args: spec_file (str): The path to the JSON schema file. The default is config/spec.json. + Returns: str: The JSON schema of the tool. """ @@ -35,6 +36,7 @@ def properties( Args: properties_file (str): The path to the properties file. The default is config/properties.json. + Returns: str: The properties of the tool. """ @@ -50,10 +52,10 @@ def variables( Args: variables_file (str): The path to the JSON schema file. The default is config/runtime_variables.json. + Returns: str: The JSON schema for the runtime variables. """ - try: return ToolUtils.load_json(variables_file, fs) # Allow runtime variables definition to be optional @@ -71,6 +73,7 @@ def icon( Args: icon_file (str): The path to the icon file. The default is config/icon.svg. + Returns: str: The icon of the tool. """ diff --git a/src/unstract/sdk/tool/parser.py b/src/unstract/sdk/tool/parser.py index 5ce1050..cf708b2 100644 --- a/src/unstract/sdk/tool/parser.py +++ b/src/unstract/sdk/tool/parser.py @@ -1,8 +1,6 @@ import argparse -from typing import Optional from dotenv import find_dotenv, load_dotenv - from unstract.sdk.constants import LogLevel @@ -45,7 +43,7 @@ def parse_args(args_to_parse: list[str]) -> argparse.Namespace: return parsed_args @staticmethod - def load_environment(path: Optional[str] = None) -> None: + def load_environment(path: str | None = None) -> None: """Loads env variables with python-dotenv. Args: diff --git a/src/unstract/sdk/tool/stream.py b/src/unstract/sdk/tool/stream.py index f4b0278..75e26aa 100644 --- a/src/unstract/sdk/tool/stream.py +++ b/src/unstract/sdk/tool/stream.py @@ -5,7 +5,6 @@ from typing import Any from deprecated import deprecated - from unstract.sdk.constants import Command, LogLevel, LogStage, ToolEnv from unstract.sdk.utils import Utils from unstract.sdk.utils.common_utils import UNSTRACT_TO_PY_LOG_LEVEL @@ -21,11 +20,10 @@ class StreamMixin: """ def __init__(self, log_level: LogLevel = LogLevel.INFO, **kwargs) -> None: - """ - Args: - log_level (LogLevel): The log level for filtering of log messages. - The default is INFO. - Allowed values are DEBUG, INFO, WARN, ERROR, and FATAL. + """Args: + log_level (LogLevel): The log level for filtering of log messages. + The default is INFO. + Allowed values are DEBUG, INFO, WARN, ERROR, and FATAL. """ self.log_level = log_level @@ -153,6 +151,7 @@ def stream_properties(properties: str) -> None: Args: properties (str): The properties of the tool. Typically returned by the properties() method. + Returns: None """ @@ -190,6 +189,7 @@ def stream_icon(icon: str) -> None: Args: icon (str): The icon of the tool. Typically returned by the icon() method. + Returns: None """ @@ -227,6 +227,7 @@ def stream_cost(cost: float, cost_units: str, **kwargs: Any) -> None: cost (float): The cost of the tool. cost_units (str): The cost units of the tool. **kwargs: Additional keyword arguments to include in the record. + Returns: None """ @@ -248,6 +249,7 @@ def stream_single_step_message(message: str, **kwargs: Any) -> None: Args: message (str): The single step message. **kwargs: Additional keyword arguments to include in the record. + Returns: None """ @@ -269,6 +271,7 @@ def stream_result(result: dict[Any, Any], **kwargs: Any) -> None: result (dict): The result of the tool. Refer to the Unstract protocol for the format of the result. **kwargs: Additional keyword arguments to include in the record. + Returns: None """ diff --git a/src/unstract/sdk/tool/validator.py b/src/unstract/sdk/tool/validator.py index 4f8edee..01a7289 100644 --- a/src/unstract/sdk/tool/validator.py +++ b/src/unstract/sdk/tool/validator.py @@ -4,7 +4,6 @@ from typing import Any from jsonschema import Draft202012Validator, ValidationError, validators - from unstract.sdk.constants import MetadataKey, PropKey from unstract.sdk.tool.base import BaseTool from unstract.sdk.tool.mime_types import EXT_MIME_MAP @@ -25,9 +24,7 @@ def extend_with_default(validator_class: Any) -> Any: """ validate_properties = validator_class.VALIDATORS["properties"] - def set_defaults( - validator: Any, properties: Any, instance: Any, schema: Any - ) -> Any: + def set_defaults(validator: Any, properties: Any, instance: Any, schema: Any) -> Any: for property_, subschema in properties.items(): if "default" in subschema: instance.setdefault(property_, subschema["default"]) @@ -87,9 +84,7 @@ def _validate_restrictions(self, input_file: Path) -> None: self._validate_file_size(input_file) self._validate_file_type(input_file) - def _validate_settings_and_fill_defaults( - self, tool_settings: dict[str, Any] - ) -> None: + def _validate_settings_and_fill_defaults(self, tool_settings: dict[str, Any]) -> None: """Validates and obtains settings for a tool. Validation is done against the tool's settings based @@ -143,9 +138,7 @@ def _parse_size_string(self, size_string: str) -> int: """ size_match = re.match(PropKey.FILE_SIZE_REGEX, size_string) if not size_match: - self.tool.stream_error_and_exit( - f"Invalid size string format: {size_string}" - ) + self.tool.stream_error_and_exit(f"Invalid size string format: {size_string}") size, unit = size_match.groups() size_in_bytes = int(size) diff --git a/src/unstract/sdk/utils/callback_manager.py b/src/unstract/sdk/utils/callback_manager.py index 03fcbd0..367ebb2 100644 --- a/src/unstract/sdk/utils/callback_manager.py +++ b/src/unstract/sdk/utils/callback_manager.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, Union +from collections.abc import Callable import tiktoken from deprecated import deprecated @@ -7,7 +7,6 @@ from llama_index.core.callbacks import TokenCountingHandler from llama_index.core.embeddings import BaseEmbedding from llama_index.core.llms import LLM - from unstract.sdk.utils.usage_handler import UsageHandler logger = logging.getLogger(__name__) @@ -37,7 +36,7 @@ class CallbackManager: @staticmethod def set_callback( platform_api_key: str, - model: Union[LLM, BaseEmbedding], + model: LLM | BaseEmbedding, kwargs, ) -> None: """Sets the standard callback manager for the llm. This is to be called @@ -57,13 +56,8 @@ def set_callback( embedding=embedding ) """ - # Nothing to do if callback manager is already set for the instance - if ( - model - and model.callback_manager - and len(model.callback_manager.handlers) > 0 - ): + if model and model.callback_manager and len(model.callback_manager.handlers) > 0: return model.callback_manager = CallbackManager.get_callback_manager( @@ -72,7 +66,7 @@ def set_callback( @staticmethod def get_callback_manager( - model: Union[LLM, BaseEmbedding], + model: LLM | BaseEmbedding, platform_api_key: str, kwargs, ) -> LlamaIndexCallbackManager: @@ -110,7 +104,7 @@ def get_callback_manager( @staticmethod def get_tokenizer( - model: Optional[Union[LLM, BaseEmbedding, None]], + model: LLM | BaseEmbedding | None, fallback_tokenizer: Callable[[str], list] = tiktoken.encoding_for_model( "gpt-3.5-turbo" ).encode, @@ -127,7 +121,6 @@ def get_tokenizer( Raises: OSError: If an error occurs while loading the tokenizer. """ - try: if isinstance(model, LLM): model_name: str = model.metadata.model_name @@ -146,8 +139,8 @@ def get_tokenizer( @deprecated("Use set_callback() instead") def set_callback_manager( platform_api_key: str, - llm: Optional[LLM] = None, - embedding: Optional[BaseEmbedding] = None, + llm: LLM | None = None, + embedding: BaseEmbedding | None = None, **kwargs, ) -> LlamaIndexCallbackManager: callback_manager: LlamaIndexCallbackManager = LlamaIndexCallbackManager() diff --git a/src/unstract/sdk/utils/indexing_utils.py b/src/unstract/sdk/utils/indexing_utils.py index 8f2f032..b0b6602 100644 --- a/src/unstract/sdk/utils/indexing_utils.py +++ b/src/unstract/sdk/utils/indexing_utils.py @@ -1,5 +1,4 @@ import json -from typing import Optional from unstract.sdk.adapter import ToolAdapter from unstract.sdk.file_storage import FileStorage, FileStorageProvider @@ -16,8 +15,8 @@ def generate_index_key( chunk_size: str, chunk_overlap: str, tool: BaseTool, - file_path: Optional[str] = None, - file_hash: Optional[str] = None, + file_path: str | None = None, + file_hash: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> str: """Generates a unique index key based on the provided configuration, diff --git a/src/unstract/sdk/utils/token_counter.py b/src/unstract/sdk/utils/token_counter.py index f2bdc88..b4b544e 100644 --- a/src/unstract/sdk/utils/token_counter.py +++ b/src/unstract/sdk/utils/token_counter.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any from llama_index.core.callbacks.schema import EventPayload from llama_index.core.llms import ChatResponse, CompletionResponse @@ -49,13 +49,11 @@ def get_llm_token_counts(payload: dict[str, Any]): @staticmethod def _get_tokens_from_response( - response: Union[CompletionResponse, ChatResponse, dict] + response: CompletionResponse | ChatResponse | dict, ) -> tuple[int, int]: """Get the token counts from a raw response.""" prompt_tokens, completion_tokens = 0, 0 - if isinstance(response, CompletionResponse) or isinstance( - response, ChatResponse - ): + if isinstance(response, CompletionResponse) or isinstance(response, ChatResponse): raw_response = response.raw if not isinstance(raw_response, dict): raw_response = dict(raw_response) diff --git a/src/unstract/sdk/utils/tool_utils.py b/src/unstract/sdk/utils/tool_utils.py index 7216b64..6d434c5 100644 --- a/src/unstract/sdk/utils/tool_utils.py +++ b/src/unstract/sdk/utils/tool_utils.py @@ -7,7 +7,6 @@ from typing import Any import magic - from unstract.sdk.exceptions import FileStorageError from unstract.sdk.file_storage import ( FileStorage, @@ -62,7 +61,6 @@ def get_hash_from_file( Returns: str: SHA256 hash of the file """ - # Adding the following DeprecationWarning manually as the package "deprecated" # does not support deprecation on static methods. warnings.warn( @@ -152,6 +150,7 @@ def get_file_size( fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), ) -> int: """Gets the file size in bytes for an input file. + Args: input_file (Path): Path object of the input file diff --git a/src/unstract/sdk/utils/usage_handler.py b/src/unstract/sdk/utils/usage_handler.py index 11853d5..68b21a4 100644 --- a/src/unstract/sdk/utils/usage_handler.py +++ b/src/unstract/sdk/utils/usage_handler.py @@ -1,11 +1,10 @@ -from typing import Any, Optional +from typing import Any from llama_index.core.callbacks import CBEventType, TokenCountingHandler from llama_index.core.callbacks.base_handler import BaseCallbackHandler from llama_index.core.callbacks.schema import EventPayload from llama_index.core.embeddings import BaseEmbedding from llama_index.core.llms import LLM - from unstract.sdk.audit import Audit from unstract.sdk.constants import LogLevel from unstract.sdk.tool.stream import StreamMixin @@ -35,11 +34,11 @@ class UsageHandler(StreamMixin, BaseCallbackHandler): def __init__( self, platform_api_key: str, - token_counter: Optional[TokenCountingHandler] = None, + token_counter: TokenCountingHandler | None = None, llm_model: LLM = None, embed_model: BaseEmbedding = None, - event_starts_to_ignore: Optional[list[CBEventType]] = None, - event_ends_to_ignore: Optional[list[CBEventType]] = None, + event_starts_to_ignore: list[CBEventType] | None = None, + event_ends_to_ignore: list[CBEventType] | None = None, verbose: bool = False, log_level: LogLevel = LogLevel.INFO, kwargs: dict[Any, Any] = None, @@ -56,20 +55,20 @@ def __init__( event_ends_to_ignore=event_ends_to_ignore or [], ) - def start_trace(self, trace_id: Optional[str] = None) -> None: + def start_trace(self, trace_id: str | None = None) -> None: return def end_trace( self, - trace_id: Optional[str] = None, - trace_map: Optional[dict[str, list[str]]] = None, + trace_id: str | None = None, + trace_map: dict[str, list[str]] | None = None, ) -> None: return def on_event_start( self, event_type: CBEventType, - payload: Optional[dict[str, Any]] = None, + payload: dict[str, Any] | None = None, event_id: str = "", parent_id: str = "", kwargs: dict[Any, Any] = None, @@ -79,7 +78,7 @@ def on_event_start( def on_event_end( self, event_type: CBEventType, - payload: Optional[dict[str, Any]] = None, + payload: dict[str, Any] | None = None, event_id: str = "", kwargs: dict[Any, Any] = None, ) -> None: diff --git a/src/unstract/sdk/vector_db.py b/src/unstract/sdk/vector_db.py index ebf3118..82ad03f 100644 --- a/src/unstract/sdk/vector_db.py +++ b/src/unstract/sdk/vector_db.py @@ -1,6 +1,6 @@ import logging from collections.abc import Sequence -from typing import Any, Optional, Union +from typing import Any from deprecated import deprecated from llama_index.core import StorageContext, VectorStoreIndex @@ -12,7 +12,6 @@ VectorStore, VectorStoreQueryResult, ) - from unstract.sdk.adapter import ToolAdapter from unstract.sdk.adapters.constants import Common from unstract.sdk.adapters.vectordb import adapters @@ -42,8 +41,8 @@ class VectorDB: def __init__( self, tool: BaseTool, - adapter_instance_id: Optional[str] = None, - embedding: Optional[Embedding] = None, + adapter_instance_id: str | None = None, + embedding: Embedding | None = None, ): self._tool = tool self._adapter_instance_id = adapter_instance_id @@ -52,14 +51,14 @@ def __init__( self._embedding_dimension = VectorDB.DEFAULT_EMBEDDING_DIMENSION self._initialise(embedding) - def _initialise(self, embedding: Optional[Embedding] = None): + def _initialise(self, embedding: Embedding | None = None): if embedding: self._embedding_instance = embedding._embedding_instance self._embedding_dimension = embedding._length if self._adapter_instance_id: - self._vector_db_instance: Union[ - BasePydanticVectorStore, VectorStore - ] = self._get_vector_db() + self._vector_db_instance: BasePydanticVectorStore | VectorStore = ( + self._get_vector_db() + ) def _get_org_id(self) -> str: platform_helper = PlatformHelper( @@ -75,7 +74,7 @@ def _get_org_id(self) -> str: account_id = platform_details.get("organization_id") return account_id - def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]: + def _get_vector_db(self) -> BasePydanticVectorStore | VectorStore: """Gets an instance of LlamaIndex's VectorStore. Returns: @@ -83,9 +82,7 @@ def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]: """ try: if not self._adapter_instance_id: - raise VectorDBError( - "Adapter instance ID not set. Initialisation failed" - ) + raise VectorDBError("Adapter instance ID not set. Initialisation failed") vector_db_config = ToolAdapter.get_adapter_config( self._tool, self._adapter_instance_id @@ -108,9 +105,9 @@ def _get_vector_db(self) -> Union[BasePydanticVectorStore, VectorStore]: org = self._get_org_id() vector_db_metadata[VectorDbConstants.VECTOR_DB_NAME] = org - vector_db_metadata[ - VectorDbConstants.EMBEDDING_DIMENSION - ] = self._embedding_dimension + vector_db_metadata[VectorDbConstants.EMBEDDING_DIMENSION] = ( + self._embedding_dimension + ) self.vector_db_adapter_class = vector_db_adapter(vector_db_metadata) return self.vector_db_adapter_class.get_vector_db_instance() @@ -151,7 +148,7 @@ def index_document( def get_vector_store_index_from_storage_context( self, documents: Sequence[Document], - storage_context: Optional[StorageContext] = None, + storage_context: StorageContext | None = None, show_progress: bool = False, callback_manager=None, **kwargs, @@ -216,7 +213,7 @@ def get_class_name(self) -> str: Args: NA - Returns: + Returns: Class name """ return self._vector_db_instance.class_name() @@ -224,7 +221,7 @@ def get_class_name(self) -> str: @deprecated("Use VectorDB instead of ToolVectorDB") def get_vector_db( self, adapter_instance_id: str, embedding_dimension: int - ) -> Union[BasePydanticVectorStore, VectorStore]: + ) -> BasePydanticVectorStore | VectorStore: if not self._vector_db_instance: self._adapter_instance_id = adapter_instance_id self._initialise() diff --git a/src/unstract/sdk/x2txt.py b/src/unstract/sdk/x2txt.py index 4bc6ba8..fadc6c4 100644 --- a/src/unstract/sdk/x2txt.py +++ b/src/unstract/sdk/x2txt.py @@ -1,10 +1,9 @@ import io from abc import ABCMeta -from typing import Any, Optional +from typing import Any import pdfplumber from deprecated import deprecated - from unstract.sdk.adapter import ToolAdapter from unstract.sdk.adapters.constants import Common from unstract.sdk.adapters.x2text import adapters @@ -26,7 +25,7 @@ class X2Text(metaclass=ABCMeta): def __init__( self, tool: BaseTool, - adapter_instance_id: Optional[str] = None, + adapter_instance_id: str | None = None, usage_kwargs: dict[Any, Any] = {}, ): self._tool = tool @@ -60,20 +59,18 @@ def _get_x2text(self) -> X2TextAdapter: ][Common.ADAPTER] x2text_metadata = x2text_config.get(Common.ADAPTER_METADATA) # Add x2text service host, port and platform_service_key - x2text_metadata[ + x2text_metadata[X2TextConstants.X2TEXT_HOST] = self._tool.get_env_or_die( X2TextConstants.X2TEXT_HOST - ] = self._tool.get_env_or_die(X2TextConstants.X2TEXT_HOST) - x2text_metadata[ + ) + x2text_metadata[X2TextConstants.X2TEXT_PORT] = self._tool.get_env_or_die( X2TextConstants.X2TEXT_PORT - ] = self._tool.get_env_or_die(X2TextConstants.X2TEXT_PORT) - - if not SdkHelper.is_public_adapter( - adapter_id=self._adapter_instance_id - ): - x2text_metadata[ - X2TextConstants.PLATFORM_SERVICE_API_KEY - ] = self._tool.get_env_or_die( - X2TextConstants.PLATFORM_SERVICE_API_KEY + ) + + if not SdkHelper.is_public_adapter(adapter_id=self._adapter_instance_id): + x2text_metadata[X2TextConstants.PLATFORM_SERVICE_API_KEY] = ( + self._tool.get_env_or_die( + X2TextConstants.PLATFORM_SERVICE_API_KEY + ) ) self._x2text_instance = x2text_adapter(x2text_metadata) @@ -90,7 +87,7 @@ def _get_x2text(self) -> X2TextAdapter: def process( self, input_file_path: str, - output_file_path: Optional[str] = None, + output_file_path: str | None = None, fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), **kwargs: dict[Any, Any], ) -> TextExtractionResult: diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 890f847..d5eae03 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -5,7 +5,6 @@ from dotenv import load_dotenv from llama_index.core.embeddings import BaseEmbedding from parameterized import parameterized - from unstract.sdk.embedding import ToolEmbedding from unstract.sdk.tool.base import BaseTool @@ -34,9 +33,7 @@ def run_embedding_test(self, adapter_instance_id): embed_model = embedding.get_embedding(adapter_instance_id) self.assertIsNotNone(embed_model) self.assertIsInstance(embed_model, BaseEmbedding) - response = embed_model._get_text_embedding( - ToolEmbeddingTest.TEST_SNIPPET - ) + response = embed_model._get_text_embedding(ToolEmbeddingTest.TEST_SNIPPET) self.assertIsNotNone(response) @parameterized.expand(get_test_values("EMBEDDING_TEST_VALUES")) diff --git a/tests/test_index.py b/tests/test_index.py index 4e60b28..9aa05f5 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -2,12 +2,11 @@ import logging import os import unittest -from typing import Any, Optional +from typing import Any from unittest.mock import Mock, patch from dotenv import load_dotenv from parameterized import parameterized - from unstract.sdk.index import Index from unstract.sdk.tool.base import BaseTool @@ -101,8 +100,8 @@ def test_generate_index_key( x2text: str, chunk_size: str, chunk_overlap: str, - file_path: Optional[str] = None, - file_hash: Optional[str] = None, + file_path: str | None = None, + file_hash: str | None = None, ): expected = "77843eb8d9e30ad56bfcb018c2633fa32feef2f0c09762b6b820c75664b64c1b" index = Index(tool=self.tool) diff --git a/tests/test_llm.py b/tests/test_llm.py index 17d2ecb..99bce15 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -7,7 +7,6 @@ from dotenv import load_dotenv from parameterized import parameterized from unstract.adapters.llm.helper import LLMHelper - from unstract.sdk.llm import ToolLLM from unstract.sdk.tool.base import BaseTool diff --git a/tests/test_vector_db.py b/tests/test_vector_db.py index 345a716..9ab8fc9 100644 --- a/tests/test_vector_db.py +++ b/tests/test_vector_db.py @@ -11,7 +11,6 @@ ) from parameterized import parameterized from unstract.adapters.vectordb.helper import VectorDBHelper - from unstract.sdk.tool.base import BaseTool from unstract.sdk.vector_db import ToolVectorDB @@ -46,9 +45,7 @@ def test_get_vector_db(self, adapter_instance_id: str) -> None: adapter_instance_id, mock_embedding.embed_dim ) self.assertIsNotNone(vector_store) - self.assertIsInstance( - vector_store, (BasePydanticVectorStore, VectorStore) - ) + self.assertIsInstance(vector_store, (BasePydanticVectorStore, VectorStore)) result = VectorDBHelper.test_vector_db_instance(vector_store) self.assertEqual(result, True)