feat: add the ability to specify a custom OCR besides the ones natively supported #2462

Merged (14 commits) on Jan 31, 2024
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -322,7 +322,7 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
PINECONE_API_KEY: ${{secrets.PINECONE_API_KEY}}
TABLE_OCR: "tesseract"
OCR_AGENT: "tesseract"
OCR_AGENT: "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
CI: "true"
run: |
source .venv/bin/activate
@@ -380,7 +380,7 @@ jobs:
AZURE_DEST_CONNECTION_STR: ${{ secrets.AZURE_DEST_CONNECTION_STR }}
PINECONE_API_KEY: ${{secrets.PINECONE_API_KEY}}
TABLE_OCR: "tesseract"
OCR_AGENT: "tesseract"
OCR_AGENT: "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
CI: "true"
run: |
source .venv/bin/activate
2 changes: 1 addition & 1 deletion .github/workflows/ingest-test-fixtures-update-pr.yml
@@ -94,7 +94,7 @@ jobs:
AZURE_SEARCH_API_KEY: ${{ secrets.AZURE_SEARCH_API_KEY }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
TABLE_OCR: "tesseract"
OCR_AGENT: "tesseract"
OCR_AGENT: "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
OVERWRITE_FIXTURES: "true"
CI: "true"
run: |
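Note the new `OCR_AGENT` contract these workflow changes encode: the variable no longer takes a short engine name like `"tesseract"` but the fully qualified path of an `OCRAgent` class. A minimal sketch of selecting an agent from Python; the Tesseract path is taken from the workflows above, and the Paddle path is inferred from the test imports later in this diff:

```python
import os

# Fully qualified class path, as used in the updated CI workflows above.
os.environ["OCR_AGENT"] = (
    "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
)

# Alternatively, the Paddle agent (path inferred from this PR's test imports):
# os.environ["OCR_AGENT"] = (
#     "unstructured.partition.utils.ocr_models.paddle_ocr.OCRAgentPaddle"
# )
```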
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,4 +1,4 @@
- ## 0.12.4-dev2
+ ## 0.12.4-dev3

### Enhancements

@@ -7,6 +7,7 @@
### Features

* **Add .heic file partitioning** .heic image files were previously unsupported and are now supported through partition_image()
+ * **Add the ability to specify an alternate OCR implementation** by implementing the `OCRAgent` interface and specifying it via the `OCR_AGENT` environment variable.

### Fixes

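The feature entry above is the core of this PR: any OCR engine can be plugged in by implementing the `OCRAgent` interface. A minimal sketch of a custom agent, assuming the base class lives in `unstructured.partition.utils.ocr_models.ocr_interface` (the module the tests below import helpers from) and that the interface mirrors the methods exercised in this diff; the base-class import path, exact signatures, and the `my_pkg.my_ocr` module are assumptions:

```python
from PIL import Image

# Assumed base-class location; confirmed only indirectly by this PR's imports.
from unstructured.partition.utils.ocr_models.ocr_interface import OCRAgent


class OCRAgentMyEngine(OCRAgent):
    """Hypothetical agent, selected via OCR_AGENT="my_pkg.my_ocr.OCRAgentMyEngine"."""

    def is_text_sorted(self) -> bool:
        # Returning True lets callers skip re-sorting (see the pdf.py hunk below).
        return False

    def get_text_from_image(self, image: Image.Image, ocr_languages: str = "eng") -> str:
        # Run the engine and return plain text; stubbed here.
        return ""

    def get_layout_from_image(self, image: Image.Image, ocr_languages: str = "eng") -> list:
        # Return one text region per recognized word or line; stubbed here.
        return []

    def get_layout_elements_from_image(self, image: Image.Image, ocr_languages: str = "eng") -> list:
        # Return layout elements built from the OCR regions; stubbed here.
        return []
```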
44 changes: 25 additions & 19 deletions test_unstructured/partition/pdf_image/test_ocr.py
@@ -16,11 +16,17 @@
from unstructured.partition.pdf_image import ocr
from unstructured.partition.pdf_image.ocr import pad_element_bboxes
from unstructured.partition.utils.constants import (
-    OCR_AGENT_PADDLE,
-    OCR_AGENT_TESSERACT,
    Source,
)
-from unstructured.partition.utils.ocr_models import paddle_ocr
+from unstructured.partition.utils.ocr_models.ocr_interface import (
+    get_elements_from_ocr_regions,
+    merge_text_regions,
+)
+from unstructured.partition.utils.ocr_models.paddle_ocr import OCRAgentPaddle
+from unstructured.partition.utils.ocr_models.tesseract_ocr import (
+    OCRAgentTesseract,
+    zoom_image,
+)


@pytest.mark.parametrize(
@@ -83,10 +89,10 @@ def test_get_ocr_layout_from_image_tesseract(monkeypatch):

image = Image.new("RGB", (100, 100))

ocr_layout = ocr.get_ocr_layout_from_image(
ocr_agent = OCRAgentTesseract()
ocr_layout = ocr_agent.get_layout_from_image(
image,
ocr_languages="eng",
ocr_agent=OCR_AGENT_TESSERACT,
)

expected_layout = [
@@ -127,7 +133,7 @@ def mock_ocr(*args, **kwargs):
]


-def monkeypatch_load_agent():
+def monkeypatch_load_agent(language: str):
    class MockAgent:
        def __init__(self):
            self.ocr = mock_ocr
@@ -137,17 +143,16 @@ def __init__(self):

def test_get_ocr_layout_from_image_paddle(monkeypatch):
    monkeypatch.setattr(
-        paddle_ocr,
+        OCRAgentPaddle,
        "load_agent",
        monkeypatch_load_agent,
    )

    image = Image.new("RGB", (100, 100))

-    ocr_layout = ocr.get_ocr_layout_from_image(
+    ocr_layout = OCRAgentPaddle().get_layout_from_image(
        image,
        ocr_languages="eng",
-        ocr_agent=OCR_AGENT_PADDLE,
    )

expected_layout = [
@@ -167,28 +172,28 @@ def test_get_ocr_text_from_image_tesseract(monkeypatch):
)
image = Image.new("RGB", (100, 100))

-    ocr_text = ocr.get_ocr_text_from_image(
+    ocr_agent = OCRAgentTesseract()
+    ocr_text = ocr_agent.get_text_from_image(
        image,
        ocr_languages="eng",
-        ocr_agent=OCR_AGENT_TESSERACT,
    )

assert ocr_text == "Hello World"


def test_get_ocr_text_from_image_paddle(monkeypatch):
    monkeypatch.setattr(
-        paddle_ocr,
+        OCRAgentPaddle,
        "load_agent",
        monkeypatch_load_agent,
    )

    image = Image.new("RGB", (100, 100))

-    ocr_text = ocr.get_ocr_text_from_image(
+    ocr_agent = OCRAgentPaddle()
+    ocr_text = ocr_agent.get_text_from_image(
        image,
        ocr_languages="eng",
-        ocr_agent=OCR_AGENT_PADDLE,
    )

assert ocr_text == "Hello\n\nWorld\n\n!"
@@ -239,7 +244,7 @@ def test_merge_text_regions(mock_embedded_text_regions):
text="LayoutParser: A Unified Toolkit for Deep Learning Based Document Image",
)

-    merged_text_region = ocr.merge_text_regions(mock_embedded_text_regions)
+    merged_text_region = merge_text_regions(mock_embedded_text_regions)
assert merged_text_region == expected


@@ -255,15 +260,15 @@ def test_get_elements_from_ocr_regions(mock_embedded_text_regions):
),
]

-    elements = ocr.get_elements_from_ocr_regions(mock_embedded_text_regions)
+    elements = get_elements_from_ocr_regions(mock_embedded_text_regions)
assert elements == expected


@pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0])
def test_zoom_image(zoom):
image = Image.new("RGB", (100, 100))
width, height = image.size
new_image = ocr.zoom_image(image, zoom)
new_image = zoom_image(image, zoom)
new_w, new_h = new_image.size
if zoom <= 0:
zoom = 1
@@ -448,7 +453,7 @@ def mock_ocr_layout():


def test_get_table_tokens(mock_ocr_layout):
-    with patch.object(ocr, "get_ocr_layout_from_image", return_value=mock_ocr_layout):
+    with patch.object(OCRAgentTesseract, "get_layout_from_image", return_value=mock_ocr_layout):
        table_tokens = ocr.get_table_tokens(image=None)
expected_tokens = [
{
@@ -488,8 +493,9 @@ def test_auto_zoom_not_exceed_tesseract_limit(monkeypatch):
)

image = Image.new("RGB", (1000, 1000))
ocr_agent = OCRAgentTesseract()
# tests that the code can run instead of oom and OCR results make sense
assert [region.text for region in ocr.get_ocr_layout_tesseract(image)] == [
assert [region.text for region in ocr_agent.get_layout_from_image(image)] == [
"Hello",
"World",
"!",
2 changes: 1 addition & 1 deletion unstructured/__version__.py
@@ -1 +1 @@
__version__ = "0.12.4-dev2" # pragma: no cover
__version__ = "0.12.4-dev3" # pragma: no cover
7 changes: 2 additions & 5 deletions unstructured/partition/pdf.py
@@ -89,7 +89,6 @@
from unstructured.partition.strategies import determine_pdf_or_image_strategy, validate_strategy
from unstructured.partition.text import element_from_text
from unstructured.partition.utils.constants import (
-    OCR_AGENT_TESSERACT,
    SORT_MODE_BASIC,
    SORT_MODE_DONT,
    SORT_MODE_XY_CUT,
@@ -929,21 +928,19 @@ def _partition_pdf_or_image_with_ocr_from_image(
"""Extract `unstructured` elements from an image using OCR and perform partitioning."""

from unstructured.partition.pdf_image.ocr import (
get_layout_elements_from_ocr,
get_ocr_agent,
)

ocr_agent = get_ocr_agent()
ocr_languages = prepare_languages_for_tesseract(languages)

# NOTE(christine): `unstructured_pytesseract.image_to_string()` returns sorted text
if ocr_agent == OCR_AGENT_TESSERACT:
if ocr_agent.is_text_sorted():
sort_mode = SORT_MODE_DONT

ocr_data = get_layout_elements_from_ocr(
ocr_data = ocr_agent.get_layout_elements_from_image(
image=image,
ocr_languages=ocr_languages,
ocr_agent=ocr_agent,
)

metadata = ElementMetadata(
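`get_ocr_agent()` itself is not part of this diff. A plausible sketch of how it could resolve the `OCR_AGENT` class path, assuming a plain importlib lookup; the actual implementation (including its default value and any handling of the legacy short names) may differ:

```python
import importlib
import os

# Default taken from the CI workflow changes above.
DEFAULT_OCR_AGENT = (
    "unstructured.partition.utils.ocr_models.tesseract_ocr.OCRAgentTesseract"
)


def get_ocr_agent_sketch():
    """Hypothetical: resolve OCR_AGENT (a fully qualified class path) to an instance."""
    fq_name = os.environ.get("OCR_AGENT", DEFAULT_OCR_AGENT)
    module_name, class_name = fq_name.rsplit(".", 1)
    agent_cls = getattr(importlib.import_module(module_name), class_name)
    return agent_cls()
```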