Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
688 changes: 245 additions & 443 deletions pdm.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
"python-magic~=0.4.27",
"python-dotenv==1.0.0",
# LLM Triad
"unstract-adapters~=0.2.1",
"unstract-adapters~=0.2.2",
"llama-index==0.9.28",
"tiktoken~=0.4.0",
"transformers==4.37.0",
Expand Down
2 changes: 1 addition & 1 deletion src/unstract/sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.11.1"
__version__ = "0.11.2"


def get_sdk_version():
Expand Down
38 changes: 38 additions & 0 deletions src/unstract/sdk/ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from abc import ABCMeta
from typing import Optional

from unstract.adapters.constants import Common
from unstract.adapters.ocr import adapters
from unstract.adapters.ocr.ocr_adapter import OCRAdapter

from unstract.sdk.adapters import ToolAdapter
from unstract.sdk.constants import LogLevel
from unstract.sdk.tool.base import BaseTool


class OCR(metaclass=ABCMeta):
    """Resolve OCR adapter instances on behalf of a tool."""

    def __init__(self, tool: BaseTool):
        """Remember the owning tool and the registry of known OCR adapters.

        Args:
            tool: Tool used to fetch adapter configuration and to emit logs.
        """
        self.ocr_adapters = adapters
        self.tool = tool

    def get_ocr(self, adapter_instance_id: str) -> Optional[OCRAdapter]:
        """Build the OCR adapter configured under ``adapter_instance_id``.

        Args:
            adapter_instance_id: Identifier passed to
                ``ToolAdapter.get_adapter_config`` to fetch the configuration.

        Returns:
            The adapter object built from the configured metadata, or
            ``None`` when the configured adapter id is not present in the
            registry or when any step raises (the failure is logged through
            the tool at ERROR level).
        """
        try:
            config = ToolAdapter.get_adapter_config(
                self.tool, adapter_instance_id
            )
            adapter_id = config.get(Common.ADAPTER_ID)
            # Adapter ids missing from the registry yield None without a log.
            if adapter_id not in self.ocr_adapters:
                return None
            registry_entry = self.ocr_adapters[adapter_id]
            adapter_factory = registry_entry[Common.METADATA][Common.ADAPTER]
            return adapter_factory(config.get(Common.ADAPTER_METADATA))
        except Exception as e:
            self.tool.stream_log(
                log=f"Unable to get OCR adapter {adapter_instance_id}: {e}",
                level=LogLevel.ERROR,
            )
        return None
3 changes: 2 additions & 1 deletion src/unstract/sdk/x2txt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from abc import ABCMeta
from typing import Optional

from unstract.adapters.constants import Common
from unstract.adapters.x2text import adapters
Expand All @@ -15,7 +16,7 @@ def __init__(self, tool: BaseTool):
self.tool = tool
self.x2text_adapters = adapters

def get_x2text(self, adapter_instance_id: str) -> X2TextAdapter:
def get_x2text(self, adapter_instance_id: str) -> Optional[X2TextAdapter]:
try:
x2text_config = ToolAdapter.get_adapter_config(
self.tool, adapter_instance_id
Expand Down
2 changes: 2 additions & 0 deletions tests/sample.env
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ X2TEXT_PORT=3004
LLM_TEST_VALUES=["", "", ""]
EMBEDDING_TEST_VALUES=["", "", ""]
VECTOR_DB_TEST_VALUES=["", "", ""]
OCR_TEST_VALUES=["", ""]
X2TEXT_TEST_VALUES=["", ""]

65 changes: 65 additions & 0 deletions tests/test_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import json
import logging
import os
import unittest
from typing import Any

from dotenv import load_dotenv
from parameterized import parameterized

from unstract.sdk.ocr import OCR
from unstract.sdk.tool.base import BaseTool

# Must run at import time, before the test class below is defined: its
# @parameterized.expand decorator reads OCR_TEST_VALUES from the environment
# while the class body is being evaluated.
load_dotenv()

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def get_test_values(env_key: str) -> list[str]:
    """Return the JSON-encoded list stored in environment variable *env_key*.

    Args:
        env_key: Name of the environment variable holding a JSON array,
            e.g. ``'["id-1", "id-2"]'``.

    Returns:
        The decoded value, or an empty list when the variable is unset or
        contains only whitespace.

    Raises:
        json.JSONDecodeError: If the variable holds text that is not valid
            JSON.
    """
    raw = os.environ.get(env_key)
    # This helper runs inside a decorator at class-definition time, so an
    # unset variable used to abort test collection with an opaque
    # ``TypeError`` from ``json.loads(None)``. Treat "not configured" as
    # "no parameters" instead.
    if raw is None or not raw.strip():
        return []
    return json.loads(raw)


def get_env_value(env_key: str) -> str:
    """Look up *env_key* in the process environment.

    Returns exactly what ``os.environ.get`` yields, which is ``None`` when
    the variable is not set.
    """
    return os.environ.get(env_key)


class ToolOCRTest(unittest.TestCase):
    """Exercise ``OCR.get_ocr`` against the adapter instances under test.

    Adapter instance ids come from the ``OCR_TEST_VALUES`` environment
    variable; the document to process and the result location come from
    ``INPUT_FILE_PATH`` and ``OUTPUT_FILE_PATH``.
    """

    class MockTool(BaseTool):
        """``BaseTool`` subclass with a no-op ``run``, used only as a handle."""

        def run(
            self,
            params: dict[str, Any] = {},
            settings: dict[str, Any] = {},
            workflow_id: str = "",
        ) -> None:
            # The dict defaults are never mutated (the body does nothing), so
            # sharing them across calls is harmless here.
            pass

    @classmethod
    def setUpClass(cls):
        """Create one tool instance shared by every parameterized case."""
        cls.tool = cls.MockTool()

    @parameterized.expand(get_test_values("OCR_TEST_VALUES"))
    def test_get_ocr(self, adapter_instance_id):
        """Resolve the adapter, test its connection and process one file."""
        tool_ocr = OCR(tool=self.tool)
        ocr = tool_ocr.get_ocr(adapter_instance_id)
        # get_ocr() is annotated Optional; fail with a readable message
        # rather than an AttributeError on the next line.
        self.assertIsNotNone(
            ocr, f"No OCR adapter resolved for {adapter_instance_id!r}"
        )
        result = ocr.test_connection()
        self.assertTrue(result)

        input_file = get_env_value("INPUT_FILE_PATH")
        output_file = get_env_value("OUTPUT_FILE_PATH")
        # Unset paths would otherwise surface as a TypeError inside os.path.
        self.assertIsNotNone(input_file, "INPUT_FILE_PATH is not set")
        self.assertIsNotNone(output_file, "OUTPUT_FILE_PATH is not set")

        # Start from a clean slate so the size check below cannot be
        # satisfied by a file left over from an earlier run.
        if os.path.isfile(output_file):
            os.remove(output_file)
        output = ocr.process(input_file, output_file)
        file_size = os.path.getsize(output_file)
        self.assertGreater(file_size, 0)

        # Second check: the value returned by process() is written out by
        # the test itself and must be non-empty as well.
        if os.path.isfile(output_file):
            os.remove(output_file)
        with open(output_file, "w", encoding="utf-8") as f:
            f.write(output)
        file_size = os.path.getsize(output_file)
        self.assertGreater(file_size, 0)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()