diff --git a/CHANGELOG.md b/CHANGELOG.md index 94e861f1..d1add3a2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.2.13-dev0 + +* Change OCR logic to be aware of PDF image elements + ## 0.2.12 * Fix for processing RGBA images diff --git a/requirements/base.txt b/requirements/base.txt index 92eb5756..d488f683 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -12,7 +12,7 @@ certifi==2022.12.7 # via requests cffi==1.15.1 # via cryptography -charset-normalizer==3.0.1 +charset-normalizer==3.1.0 # via # pdfminer-six # requests @@ -22,25 +22,26 @@ coloredlogs==15.0.1 # via onnxruntime contourpy==1.0.7 # via matplotlib -cryptography==39.0.1 +cryptography==39.0.2 # via pdfminer-six cycler==0.11.0 # via matplotlib effdet==0.3.0 # via layoutparser -fastapi==0.92.0 +fastapi==0.95.0 # via unstructured-inference (setup.py) -filelock==3.9.0 +filelock==3.10.0 # via # huggingface-hub + # torch # transformers -flatbuffers==23.1.21 +flatbuffers==23.3.3 # via onnxruntime -fonttools==4.38.0 +fonttools==4.39.2 # via matplotlib h11==0.14.0 # via uvicorn -huggingface-hub==0.12.1 +huggingface-hub==0.13.3 # via # timm # transformers @@ -55,14 +56,20 @@ importlib-resources==5.12.0 # via matplotlib iopath==0.1.10 # via layoutparser +jinja2==3.1.2 + # via torch kiwisolver==1.4.4 # via matplotlib layoutparser[layoutmodels,tesseract]==0.3.4 # via unstructured-inference (setup.py) -matplotlib==3.7.0 +markupsafe==2.1.2 + # via jinja2 +matplotlib==3.7.1 # via pycocotools -mpmath==1.2.1 +mpmath==1.3.0 # via sympy +networkx==3.0 + # via torch numpy==1.24.2 # via # contourpy @@ -92,7 +99,7 @@ packaging==23.0 # transformers pandas==1.5.3 # via layoutparser -pdf2image==1.16.2 +pdf2image==1.16.3 # via layoutparser pdfminer-six==20221105 # via pdfplumber @@ -108,13 +115,13 @@ pillow==9.4.0 # torchvision portalocker==2.7.0 # via iopath -protobuf==4.22.0 +protobuf==4.22.1 # via onnxruntime pycocotools==2.0.6 # via effdet pycparser==2.21 # via cffi -pydantic==1.10.5 +pydantic==1.10.6 # via fastapi pyparsing==3.0.9 # via matplotlib @@ -142,39 +149,39 @@ requests==2.28.2 # huggingface-hub # torchvision # transformers -scipy==1.10.0 +scipy==1.10.1 # via layoutparser six==1.16.0 - # via - # python-dateutil - # python-multipart + # via python-dateutil sniffio==1.3.0 # via anyio -starlette==0.25.0 +starlette==0.26.1 # via fastapi sympy==1.11.1 - # via onnxruntime + # via + # onnxruntime + # torch timm==0.6.12 # via effdet tokenizers==0.13.2 # via transformers -torch==1.13.1 +torch==2.0.0 # via # effdet # layoutparser # timm # torchvision -torchvision==0.14.1 +torchvision==0.15.1 # via # effdet # layoutparser # timm -tqdm==4.64.1 +tqdm==4.65.0 # via # huggingface-hub # iopath # transformers -transformers==4.26.1 +transformers==4.27.2 # via unstructured-inference (setup.py) typing-extensions==4.5.0 # via @@ -183,12 +190,11 @@ typing-extensions==4.5.0 # pydantic # starlette # torch - # torchvision -urllib3==1.26.14 +urllib3==1.26.15 # via requests -uvicorn==0.20.0 +uvicorn==0.21.1 # via unstructured-inference (setup.py) wand==0.6.11 # via pdfplumber -zipp==3.14.0 +zipp==3.15.0 # via importlib-resources diff --git a/requirements/dev.txt b/requirements/dev.txt index 22577d67..3f580044 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -21,7 +21,7 @@ attrs==22.2.0 # via jsonschema backcall==0.2.0 # via ipython -beautifulsoup4==4.11.2 +beautifulsoup4==4.12.0 # via nbconvert bleach==6.0.0 # via nbconvert @@ -41,7 +41,7 @@ defusedxml==0.7.1 # via nbconvert executing==1.2.0 # via stack-data -fastjsonschema==2.16.2 +fastjsonschema==2.16.3 # via nbformat fqdn==1.5.1 # via jsonschema @@ -49,9 +49,14 @@ idna==3.4 # via # anyio # jsonschema -ipykernel==6.21.2 +importlib-metadata==6.1.0 + # via + # jupyter-client + # nbconvert +importlib-resources==5.12.0 + # via jsonschema +ipykernel==6.22.0 # via - # ipywidgets # jupyter # jupyter-console # nbclassic @@ -68,7 +73,7 @@ ipython-genutils==0.2.0 # nbclassic # notebook # qtconsole -ipywidgets==8.0.4 +ipywidgets==8.0.5 # via jupyter isoduration==20.11.0 # via jsonschema @@ -87,8 +92,8 @@ jsonschema[format-nongpl]==4.17.3 # jupyter-events # nbformat jupyter==1.0.0 - # via -r dev.in -jupyter-client==8.0.3 + # via -r requirements/dev.in +jupyter-client==8.1.0 # via # ipykernel # jupyter-console @@ -97,9 +102,9 @@ jupyter-client==8.0.3 # nbclient # notebook # qtconsole -jupyter-console==6.5.1 +jupyter-console==6.6.3 # via jupyter -jupyter-core==5.2.0 +jupyter-core==5.3.0 # via # ipykernel # jupyter-client @@ -113,7 +118,7 @@ jupyter-core==5.2.0 # qtconsole jupyter-events==0.6.3 # via jupyter-server -jupyter-server==2.3.0 +jupyter-server==2.5.0 # via # nbclassic # notebook-shim @@ -121,7 +126,7 @@ jupyter-server-terminals==0.4.4 # via jupyter-server jupyterlab-pygments==0.2.2 # via nbconvert -jupyterlab-widgets==3.0.5 +jupyterlab-widgets==3.0.6 # via ipywidgets markupsafe==2.1.2 # via @@ -133,17 +138,17 @@ matplotlib-inline==0.1.6 # ipython mistune==2.0.5 # via nbconvert -nbclassic==0.5.1 +nbclassic==0.5.3 # via notebook nbclient==0.7.2 # via nbconvert -nbconvert==7.2.9 +nbconvert==7.2.10 # via # jupyter # jupyter-server # nbclassic # notebook -nbformat==5.7.3 +nbformat==5.8.0 # via # jupyter-server # nbclassic @@ -155,7 +160,7 @@ nest-asyncio==1.5.6 # ipykernel # nbclassic # notebook -notebook==6.5.2 +notebook==6.5.3 # via jupyter notebook-shim==0.2.2 # via nbclassic @@ -165,6 +170,7 @@ packaging==23.0 # ipykernel # jupyter-server # nbconvert + # qtconsole # qtpy pandocfilters==1.5.0 # via nbconvert @@ -174,16 +180,18 @@ pexpect==4.8.0 # via ipython pickleshare==0.7.5 # via ipython -pip-tools==6.12.2 - # via -r dev.in -platformdirs==3.0.0 +pip-tools==6.12.3 + # via -r requirements/dev.in +pkgutil-resolve-name==1.3.10 + # via jsonschema +platformdirs==3.1.1 # via jupyter-core prometheus-client==0.16.0 # via # jupyter-server # nbclassic # notebook -prompt-toolkit==3.0.36 +prompt-toolkit==3.0.38 # via # ipython # jupyter-console @@ -211,11 +219,11 @@ python-dateutil==2.8.2 # via # arrow # jupyter-client -python-json-logger==2.0.6 +python-json-logger==2.0.7 # via jupyter-events pyyaml==6.0 # via jupyter-events -pyzmq==25.0.0 +pyzmq==25.0.2 # via # ipykernel # jupyter-client @@ -224,7 +232,7 @@ pyzmq==25.0.0 # nbclassic # notebook # qtconsole -qtconsole==5.4.0 +qtconsole==5.4.1 # via jupyter qtpy==2.3.0 # via qtconsole @@ -299,10 +307,14 @@ webencodings==0.5.1 # tinycss2 websocket-client==1.5.1 # via jupyter-server -wheel==0.38.4 +wheel==0.40.0 # via pip-tools -widgetsnbextension==4.0.5 +widgetsnbextension==4.0.6 # via ipywidgets +zipp==3.15.0 + # via + # importlib-metadata + # importlib-resources # The following packages are considered to be unsafe in a requirements file: # pip diff --git a/requirements/test.in b/requirements/test.in index f57d2de7..99ee65d3 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -8,6 +8,7 @@ click>=8.1 # starlette even though it's required for TestClient httpx flake8 +flake8-docstrings mypy pytest-cov pdf2image>=1.16.2 diff --git a/requirements/test.txt b/requirements/test.txt index 0f710414..367afca5 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -17,21 +17,25 @@ certifi==2022.12.7 # httpcore # httpx # requests -charset-normalizer==3.0.1 +charset-normalizer==3.1.0 # via requests click==8.1.3 # via # -r requirements/test.in # black -coverage[toml]==7.1.0 +coverage[toml]==7.2.2 # via # -r requirements/test.in # pytest-cov -exceptiongroup==1.1.0 +exceptiongroup==1.1.1 # via pytest -filelock==3.9.0 +filelock==3.10.0 # via huggingface-hub flake8==6.0.0 + # via + # -r requirements/test.in + # flake8-docstrings +flake8-docstrings==1.7.0 # via -r requirements/test.in h11==0.14.0 # via httpcore @@ -39,7 +43,7 @@ httpcore==0.16.3 # via httpx httpx==0.23.3 # via -r requirements/test.in -huggingface-hub==0.12.1 +huggingface-hub==0.13.3 # via -r requirements/test.in idna==3.4 # via @@ -49,7 +53,7 @@ idna==3.4 # yarl iniconfig==2.0.0 # via pytest -label-studio-sdk==0.0.19 +label-studio-sdk==0.0.21 # via -r requirements/test.in label-studio-tools==0.0.2 # via label-studio-sdk @@ -61,7 +65,7 @@ mccabe==0.7.0 # via flake8 multidict==6.0.4 # via yarl -mypy==1.0.1 +mypy==1.1.1 # via -r requirements/test.in mypy-extensions==1.0.0 # via @@ -72,23 +76,25 @@ packaging==23.0 # black # huggingface-hub # pytest -pathspec==0.11.0 +pathspec==0.11.1 # via black -pdf2image==1.16.2 +pdf2image==1.16.3 # via -r requirements/test.in pillow==9.4.0 # via pdf2image -platformdirs==3.0.0 +platformdirs==3.1.1 # via black pluggy==1.0.0 # via pytest pycodestyle==2.10.0 # via flake8 -pydantic==1.10.5 +pydantic==1.10.6 # via label-studio-sdk +pydocstyle==6.3.0 + # via flake8-docstrings pyflakes==3.0.1 # via flake8 -pytest==7.2.1 +pytest==7.2.2 # via pytest-cov pytest-cov==4.0.0 # via -r requirements/test.in @@ -109,13 +115,15 @@ sniffio==1.3.0 # anyio # httpcore # httpx +snowballstemmer==2.2.0 + # via pydocstyle tomli==2.0.1 # via # black # coverage # mypy # pytest -tqdm==4.64.1 +tqdm==4.65.0 # via huggingface-hub typing-extensions==4.5.0 # via @@ -123,11 +131,11 @@ typing-extensions==4.5.0 # huggingface-hub # mypy # pydantic -urllib3==1.26.14 +urllib3==1.26.15 # via requests vcrpy==4.2.1 # via -r requirements/test.in -wrapt==1.14.1 +wrapt==1.15.0 # via vcrpy yarl==1.8.2 # via vcrpy diff --git a/sample-docs/loremipsum-flat.pdf b/sample-docs/loremipsum-flat.pdf new file mode 100644 index 00000000..3742a74a Binary files /dev/null and b/sample-docs/loremipsum-flat.pdf differ diff --git a/setup.cfg b/setup.cfg index a06a3629..78525135 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,6 +3,9 @@ license_files = LICENSE.md [flake8] max-line-length = 100 +ignore = D100, D101, D104, D105, D107, D2, D4 +per-file-ignores = + test_*/**: D [tool:pytest] filterwarnings = diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index 251ee4cb..6e322a39 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -1,16 +1,14 @@ from functools import partial +from itertools import product import pytest import tempfile -from unittest.mock import patch, mock_open, Mock +from unittest.mock import patch, mock_open -import layoutparser as lp -from layoutparser.elements import Layout, Rectangle, TextBlock import numpy as np from PIL import Image import unstructured_inference.inference.layout as layout import unstructured_inference.models.base as models - import unstructured_inference.models.detectron2 as detectron2 import unstructured_inference.models.tesseract as tesseract @@ -22,17 +20,15 @@ def mock_image(): @pytest.fixture def mock_page_layout(): - text_rectangle = Rectangle(2, 4, 6, 8) - text_block = TextBlock(text_rectangle, text="A very repetitive narrative. " * 10, type="Text") + text_block = layout.TextRegion(2, 4, 6, 8, text="A very repetitive narrative. " * 10) - title_rectangle = Rectangle(1, 2, 3, 4) - title_block = TextBlock(title_rectangle, text="A Catchy Title", type="Title") + title_block = layout.TextRegion(1, 2, 3, 4, text="A Catchy Title") - return Layout([text_block, title_block]) + return [text_block, title_block] def test_pdf_page_converts_images_to_array(mock_image): - page = layout.PageLayout(number=0, image=mock_image, layout=Layout()) + page = layout.PageLayout(number=0, image=mock_image, layout=[]) assert page.image_array is None image_array = page._get_image_array() @@ -50,9 +46,8 @@ def detect(self, *args): monkeypatch.setattr(tesseract, "ocr_agent", MockOCRAgent) monkeypatch.setattr(tesseract, "is_pytesseract_available", lambda *args: True) - image = np.random.randint(12, 24, (40, 40)) - rectangle = Rectangle(1, 2, 3, 4) - text_block = TextBlock(rectangle, text=None) + image = Image.fromarray(np.random.randint(12, 24, (40, 40)), mode="RGB") + text_block = layout.TextRegion(1, 2, 3, 4, text=None) assert layout.ocr(text_block, image=image) == mock_text @@ -95,28 +90,27 @@ def join(self): def test_get_page_elements_with_ocr(monkeypatch): - rectangle = Rectangle(2, 4, 6, 8) - text_block = TextBlock(rectangle, text=None, type="Title") - doc_layout = Layout([text_block]) + text_block = layout.TextRegion(2, 4, 6, 8, text=None) + image_block = layout.ImageTextRegion(8, 14, 16, 18) + doc_layout = [text_block, image_block] monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) monkeypatch.setattr(layout, "ocr", lambda *args: "An Even Catchier Title") image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB") - print(layout.ocr(text_block, image)) page = layout.PageLayout( number=0, image=image, layout=doc_layout, model=MockLayoutModel(doc_layout) ) page.get_elements() - assert str(page) == "An Even Catchier Title" + assert str(page) == "\n\nAn Even Catchier Title" def test_read_pdf(monkeypatch, mock_page_layout): image = np.random.randint(12, 24, (40, 40)) images = [image, image] - layouts = Layout([mock_page_layout, mock_page_layout]) + layouts = [mock_page_layout, mock_page_layout] monkeypatch.setattr( models, "UnstructuredDetectronModel", partial(MockLayoutModel, layout=mock_page_layout) @@ -180,7 +174,7 @@ def tolist(self): return [1, 2, 3, 4] -class MockTextBlock(lp.TextBlock): +class MockTextRegion(layout.TextRegion): def __init__(self, type=None, text=None, ocr_text=None): self.type = type self.text = text @@ -198,18 +192,10 @@ def __init__(self, layout=None, model=None, ocr_strategy="auto"): self.model = model self.ocr_strategy = ocr_strategy - def ocr(self, text_block: MockTextBlock): + def ocr(self, text_block: MockTextRegion): return text_block.ocr_text -def test_interpret_text_block_use_ocr_when_text_symbols_cid(mock_image): - fake_text_block = Mock() - fake_text_block.text = "(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)" - with patch("unstructured_inference.inference.layout.ocr"): - layout.interpret_text_block(fake_text_block, mock_image) - layout.ocr.assert_called_once() - - @pytest.mark.parametrize( "text, expected", [("base", 0.0), ("", 0.0), ("(cid:2)", 1.0), ("(cid:1)a", 0.5), ("c(cid:1)ab", 0.25)], @@ -251,18 +237,15 @@ def filter_by(self, *args, **kwargs): [ ("no ocr", ["pieced", "together", "group"], "no ocr"), (None, ["pieced", "together", "group"], "pieced together group"), - (None, [None, None, "one"], "ocr ocr one"), ], ) def test_get_element_from_block(block_text, layout_texts, mock_image, expected_text): with patch("unstructured_inference.inference.layout.ocr", return_value="ocr"): - block = TextBlock(Rectangle(0, 0, 10, 10), text=block_text) - captured_layout = Layout( - [ - TextBlock(Rectangle(i + 1, i + 1, i + 2, i + 2), text=text) - for i, text in enumerate(layout_texts) - ] - ) + block = layout.TextRegion(0, 0, 10, 10, text=block_text) + captured_layout = [ + layout.TextRegion(i + 1, i + 1, i + 2, i + 2, text=text) + for i, text in enumerate(layout_texts) + ] assert ( layout.get_element_from_block(block, mock_image, captured_layout).text == expected_text ) @@ -270,7 +253,7 @@ def test_get_element_from_block(block_text, layout_texts, mock_image, expected_t def test_get_elements_from_block_raises(): with pytest.raises(ValueError): - block = TextBlock(Rectangle(0, 0, 10, 10), text=None) + block = layout.TextRegion(0, 0, 10, 10, text=None) layout.get_element_from_block(block, None, None) @@ -308,8 +291,8 @@ def test_from_file_raises_on_length_mismatch(monkeypatch): @pytest.mark.parametrize("idx", range(2)) def test_get_elements_from_layout(mock_page_layout, idx): page = MockPageLayout(layout=mock_page_layout) - block = mock_page_layout._blocks[idx].pad(3) - fixed_layout = Layout(blocks=[block]) + block = mock_page_layout[idx].pad(3) + fixed_layout = [block] elements = page.get_elements_from_layout(fixed_layout) assert elements[0].text == block.text @@ -342,7 +325,77 @@ def test_remove_control_characters(text, expected): assert layout.remove_control_characters(text) == expected -def test_interpret_called_when_filter_empty(mock_image): - with patch("unstructured_inference.inference.layout.interpret_text_block"): - layout.aggregate_by_block(MockTextBlock(), mock_image, MockLayout()) - layout.interpret_text_block.assert_called_once() +no_text_region = layout.TextRegion(0, 0, 100, 100) +text_region = layout.TextRegion(0, 0, 100, 100, text="test") +cid_text_region = layout.TextRegion(0, 0, 100, 100, text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)") +overlapping_rect = layout.ImageTextRegion(50, 50, 150, 150) +nonoverlapping_rect = layout.ImageTextRegion(150, 150, 200, 200) +populated_text_region = layout.TextRegion(50, 50, 60, 60, text="test") +unpopulated_text_region = layout.TextRegion(50, 50, 60, 60, text=None) + + +@pytest.mark.parametrize( + ("region", "text_objects", "image_objects", "ocr_strategy", "expected"), + [ + (no_text_region, [], [nonoverlapping_rect], "auto", False), + (no_text_region, [], [overlapping_rect], "auto", True), + (no_text_region, [], [], "auto", False), + (no_text_region, [populated_text_region], [nonoverlapping_rect], "auto", False), + (no_text_region, [populated_text_region], [overlapping_rect], "auto", False), + (no_text_region, [populated_text_region], [], "auto", False), + (no_text_region, [unpopulated_text_region], [nonoverlapping_rect], "auto", False), + (no_text_region, [unpopulated_text_region], [overlapping_rect], "auto", True), + (no_text_region, [unpopulated_text_region], [], "auto", False), + *list( + product( + [text_region], + [[], [populated_text_region], [unpopulated_text_region]], + [[], [nonoverlapping_rect], [overlapping_rect]], + ["auto"], + [False], + ) + ), + *list( + product( + [cid_text_region], + [[], [populated_text_region], [unpopulated_text_region]], + [[overlapping_rect]], + ["auto"], + [True], + ) + ), + *list( + product( + [no_text_region, text_region, cid_text_region], + [[], [populated_text_region], [unpopulated_text_region]], + [[], [nonoverlapping_rect], [overlapping_rect]], + ["force"], + [True], + ) + ), + *list( + product( + [no_text_region, text_region, cid_text_region], + [[], [populated_text_region], [unpopulated_text_region]], + [[], [nonoverlapping_rect], [overlapping_rect]], + ["never"], + [False], + ) + ), + ], +) +def test_ocr_image(region, text_objects, image_objects, ocr_strategy, expected): + assert layout.needs_ocr(region, text_objects, image_objects, ocr_strategy) is expected + + +def test_load_pdf(): + layouts, images = layout.load_pdf("sample-docs/loremipsum.pdf") + assert len(layouts) + assert len(images) + assert len(layouts) == len(images) + + +def test_load_pdf_with_images(): + layouts, _ = layout.load_pdf("sample-docs/loremipsum-flat.pdf") + first_page_layout = layouts[0] + assert any(isinstance(obj, layout.ImageTextRegion) for obj in first_page_layout) diff --git a/test_unstructured_inference/models/test_detectron2.py b/test_unstructured_inference/models/test_detectron2.py index 7d302759..60cd5a3a 100644 --- a/test_unstructured_inference/models/test_detectron2.py +++ b/test_unstructured_inference/models/test_detectron2.py @@ -11,11 +11,7 @@ def __init__(self, *args, **kwargs): self.kwargs = kwargs def detect(self, x): - return MockLayout() - - -class MockLayout: - pass + return [] def test_load_default_model(monkeypatch): @@ -46,4 +42,6 @@ def test_load_model(monkeypatch, config_path, model_path): def test_unstructured_detectron_model(): model = detectron2.UnstructuredDetectronModel() model.model = MockDetectron2LayoutModel() - assert isinstance(model(None), MockLayout) + result = model(None) + assert isinstance(result, list) + assert len(result) == 0 diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index b9b2517f..b42b08af 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.2.12" # pragma: no cover +__version__ = "0.2.13-dev0" # pragma: no cover diff --git a/unstructured_inference/api.py b/unstructured_inference/api.py index 2131d0d7..9a0a5483 100644 --- a/unstructured_inference/api.py +++ b/unstructured_inference/api.py @@ -18,6 +18,7 @@ async def layout_parsing( force_ocr=Form(default=False), # TODO(alan): Need a way to send model options to the model ): + """Route to proper filetype parser.""" if filetype not in VALID_FILETYPES: raise HTTPException(status.HTTP_404_NOT_FOUND) is_image = filetype == "image" @@ -46,4 +47,5 @@ async def layout_parsing( @app.get("/healthcheck", status_code=status.HTTP_200_OK) async def healthcheck(request: Request): + """Return healthy status""" return {"healthcheck": "HEALTHCHECK STATUS: EVERYTHING OK!"} diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py new file mode 100644 index 00000000..7394c402 --- /dev/null +++ b/unstructured_inference/inference/elements.py @@ -0,0 +1,96 @@ +from __future__ import annotations +from copy import deepcopy +from dataclasses import dataclass +from typing import Optional + +from layoutparser.elements.layout import TextBlock + + +@dataclass +class Rectangle: + x1: int + y1: int + x2: int + y2: int + + def pad(self, padding: int): + """Increases (or decreases, if padding is negative) the size of the rectangle by extending + the boundary outward (resp. inward).""" + out_object = deepcopy(self) + out_object.x1 -= padding + out_object.y1 -= padding + out_object.x2 += padding + out_object.y2 += padding + return out_object + + @property + def width(self): + """Width of rectangle""" + return self.x2 - self.x1 + + @property + def height(self): + """Height of rectangle""" + return self.y2 - self.y1 + + def is_disjoint(self, other: Rectangle): + """Checks whether this rectangle is disjoint from another rectangle.""" + return ((self.x2 < other.x1) or (self.x1 > other.x2)) and ( + (self.y2 < other.y1) or (self.y1 > other.y2) + ) + + def intersects(self, other: Rectangle): + """Checks whether this rectangle intersects another rectangle.""" + return not self.is_disjoint(other) + + def is_in(self, other: Rectangle, error_margin: Optional[int] = None): + """Checks whether this rectangle is contained within another rectangle.""" + if error_margin is not None: + padded_other = other.pad(error_margin) + else: + padded_other = other + return all( + [ + (self.x1 >= padded_other.x1), + (self.x2 <= padded_other.x2), + (self.y1 >= padded_other.y1), + (self.y2 <= padded_other.y2), + ] + ) + + +@dataclass +class TextRegion(Rectangle): + text: Optional[str] = None + + def __str__(self) -> str: + return str(self.text) + + +class ImageTextRegion(TextRegion): + pass + + +@dataclass +class LayoutElement(TextRegion): + type: Optional[str] = None + + def to_dict(self) -> dict: + """Converts the class instance to dictionary form.""" + return self.__dict__ + + @classmethod + def from_region(cls, region: Rectangle): + """Create LayoutElement from superclass.""" + x1, y1, x2, y2 = region.x1, region.y1, region.x2, region.y2 + text = region.text if hasattr(region, "text") else None + type = region.type if hasattr(region, "type") else None + return cls(x1, y1, x2, y2, text, type) + + @classmethod + def from_lp_textblock(cls, textblock: TextBlock): + """Create LayoutElement from layoutparser TextBlock object.""" + x1, y1, x2, y2 = textblock.coordinates + text = textblock.text + type = textblock.type + return cls(x1, y1, x2, y2, text, type) diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index 9a11663a..cc7f1cfb 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -1,17 +1,17 @@ from __future__ import annotations -from dataclasses import dataclass import os import re import tempfile from tqdm import tqdm from typing import List, Optional, Tuple, Union, BinaryIO import unicodedata -from layoutparser.io.pdf import load_pdf -from layoutparser.elements.layout_elements import TextBlock -from layoutparser.elements.layout import Layout + import numpy as np +import pdfplumber +import pdf2image from PIL import Image +from unstructured_inference.inference.elements import TextRegion, ImageTextRegion, LayoutElement from unstructured_inference.logger import logger import unstructured_inference.models.tesseract as tesseract from unstructured_inference.models.base import get_model @@ -24,32 +24,6 @@ ) -@dataclass -# NOTE(alan): I notice this has (almost?) the same structure as a layoutparser TextBlock. Maybe we -# don't need to make our own here? -class LayoutElement: - type: str - # NOTE(robinson) - The list contain two elements, each a tuple - # in format (x1,y1), the first the upper left corner and the second - # the right bottom corner - coordinates: List[Tuple[float, float]] - text: Optional[str] = None - - def __str__(self) -> str: - return str(self.text) - - def to_dict(self) -> dict: - return self.__dict__ - - def get_width(self) -> float: - # NOTE(benjamin) i.e: y2-y1 - return self.coordinates[1][0] - self.coordinates[0][0] - - def get_height(self) -> float: - # NOTE(benjamin) i.e: x2-x1 - return self.coordinates[1][1] - self.coordinates[0][1] - - class DocumentLayout: """Class for handling documents that are saved as .pdf files. For .pdf files, a document image analysis (DIA) model detects the layout of the page prior to extracting @@ -78,16 +52,12 @@ def from_file( cls, filename: str, model: Optional[UnstructuredModel] = None, - fixed_layouts: Optional[List[Optional[Layout]]] = None, + fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, ocr_strategy: str = "auto", ) -> DocumentLayout: """Creates a DocumentLayout from a pdf file.""" - # NOTE(alan): For now the model is a Detectron2LayoutModel but in the future it should - # be an abstract class that supports some standard interface and can accomodate either - # a locally instantiated model or an API. Maybe even just a callable that accepts an - # image and returns a dict, or something. logger.info(f"Reading PDF for file: {filename} ...") - layouts, images = load_pdf(filename, load_images=True) + layouts, images = load_pdf(filename) if len(layouts) > len(images): raise RuntimeError( "Some images were not loaded. Check that poppler is installed and in your $PATH." @@ -114,7 +84,7 @@ def from_image_file( filename: str, model: Optional[UnstructuredModel] = None, ocr_strategy: str = "auto", - fixed_layout: Optional[Layout] = None, + fixed_layout: Optional[List[TextRegion]] = None, ) -> DocumentLayout: """Creates a DocumentLayout from an image file.""" logger.info(f"Reading image file: {filename} ...") @@ -138,7 +108,7 @@ def __init__( self, number: int, image: Image, - layout: Layout, + layout: Optional[List[TextRegion]], model: Optional[UnstructuredModel] = None, ocr_strategy: str = "auto", ): @@ -170,13 +140,12 @@ def get_elements(self, inplace=True) -> Optional[List[LayoutElement]]: return None return elements - def get_elements_from_layout(self, layout: Layout) -> List[LayoutElement]: + def get_elements_from_layout(self, layout: List[TextRegion]) -> List[LayoutElement]: """Uses the given Layout to separate the page text into elements, either extracting the text from the discovered layout blocks or from the image using OCR.""" # NOTE(robinson) - This orders the page from top to bottom. We'll need more # sophisticated ordering logic for more complicated layouts. - layout.sort(key=lambda element: element.coordinates[1], inplace=True) - # NOTE(benjamin): Creates a Pool for concurrent processing of image elements by OCR + layout.sort(key=lambda element: element.y1) elements = [] for e in tqdm(layout): elements.append(get_element_from_block(e, self.image, self.layout, self.ocr_strategy)) @@ -193,9 +162,9 @@ def from_image( cls, image, model: Optional[UnstructuredModel] = None, - layout: Optional[Layout] = None, + layout: Optional[List[TextRegion]] = None, ocr_strategy: str = "auto", - fixed_layout: Optional[Layout] = None, + fixed_layout: Optional[List[TextRegion]] = None, ): """Creates a PageLayout from an already-loaded PIL Image.""" page = cls(number=0, image=image, layout=layout, model=model, ocr_strategy=ocr_strategy) @@ -211,7 +180,7 @@ def process_data_with_model( model_name: Optional[str], is_image: bool = False, ocr_strategy: str = "auto", - fixed_layouts: Optional[List[Optional[Layout]]] = None, + fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, ) -> DocumentLayout: """Processes pdf file in the form of a file handler (supporting a read method) into a DocumentLayout by using a model identified by model_name.""" @@ -233,7 +202,7 @@ def process_file_with_model( model_name: Optional[str], is_image: bool = False, ocr_strategy: str = "auto", - fixed_layouts: Optional[List[Optional[Layout]]] = None, + fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, ) -> DocumentLayout: """Processes pdf file with name filename into a DocumentLayout by using a model identified by model_name.""" @@ -266,9 +235,9 @@ def is_cid_present(text: str) -> bool: def get_element_from_block( - block: TextBlock, + block: TextRegion, image: Optional[Image.Image] = None, - layout: Optional[Layout] = None, + pdf_objects: Optional[List[Union[TextRegion, ImageTextRegion]]] = None, ocr_strategy: str = "auto", ) -> LayoutElement: """Creates a LayoutElement from a given layout or image by finding all the text that lies within @@ -276,66 +245,51 @@ def get_element_from_block( if block.text is not None: # If block text is already populated, we'll assume it's correct text = block.text - elif layout is not None: - text = aggregate_by_block(block, image, layout, ocr_strategy) + elif pdf_objects is not None: + text = aggregate_by_block(block, image, pdf_objects, ocr_strategy) elif image is not None: - text = interpret_text_block(block, image, ocr_strategy) + # We don't have anything to go on but the image itself, so we use OCR + text = ocr(block, image) else: raise ValueError( "Got arguments image and layout as None, at least one must be populated to use for " "text extraction." ) - element = LayoutElement(type=block.type, text=text, coordinates=block.points.tolist()) + element = LayoutElement.from_region(block) + element.text = text return element def aggregate_by_block( - text_block: TextBlock, + text_region: TextRegion, image: Optional[Image.Image], - layout: Layout, + pdf_objects: List[Union[TextRegion, ImageTextRegion]], ocr_strategy: str = "auto", ) -> str: """Extracts the text aggregated from the elements of the given layout that lie within the given block.""" - filtered_blocks = layout.filter_by(text_block, center=True) - # NOTE(alan): For now, if none of the elements discovered by layoutparser are in the block - # we can try interpreting the whole block. This still doesn't handle edge cases, like when there - # are some text elements within the block, but there are image elements overlapping the block - # with text lying within the block. In this case the text in the image would likely be ignored. - if not filtered_blocks: - text = interpret_text_block(text_block, image, ocr_strategy) - return text - for little_block in filtered_blocks: - little_block.text = interpret_text_block(little_block, image, ocr_strategy) - text = " ".join([x for x in filtered_blocks.get_texts() if x]) - return text - - -def interpret_text_block( - text_block: TextBlock, image: Image.Image, ocr_strategy: str = "auto" -) -> str: - """Interprets the text in a TextBlock using OCR or the text attribute, according to the given - ocr_strategy.""" - # NOTE(robinson) - If the text attribute is None, that means the PDF isn't - # already OCR'd and we have to send the snippet out for OCRing. - - if (ocr_strategy == "force") or ( - ocr_strategy == "auto" and ((text_block.text is None) or cid_ratio(text_block.text) > 0.5) - ): - out_text = ocr(text_block, image) + word_objects = [obj for obj in pdf_objects if isinstance(obj, TextRegion)] + image_objects = [obj for obj in pdf_objects if isinstance(obj, ImageTextRegion)] + if image is not None and needs_ocr(text_region, word_objects, image_objects, ocr_strategy): + text = ocr(text_region, image) else: - out_text = "" if text_block.text is None else text_block.text - out_text = remove_control_characters(out_text) - return out_text + filtered_blocks = [obj for obj in pdf_objects if obj.is_in(text_region, error_margin=5)] + for little_block in filtered_blocks: + if image is not None and needs_ocr( + little_block, word_objects, image_objects, ocr_strategy + ): + little_block.text = ocr(little_block, image) + text = " ".join([x.text for x in filtered_blocks if x.text]) + text = remove_control_characters(text) + return text -def ocr(text_block: TextBlock, image: Image.Image) -> str: +def ocr(text_block: TextRegion, image: Image.Image) -> str: """Runs a cropped text block image through and OCR agent.""" logger.debug("Running OCR on text block ...") tesseract.load_agent() - image_array = np.array(image) - padded_block = text_block.pad(left=5, right=5, top=5, bottom=5) - cropped_image = padded_block.crop_image(image_array) + padded_block = text_block.pad(5) + cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) return tesseract.ocr_agent.detect(cropped_image) @@ -343,3 +297,80 @@ def remove_control_characters(text: str) -> str: """Removes control characters from text.""" out_text = "".join(c for c in text if unicodedata.category(c)[0] != "C") return out_text + + +def load_pdf( + filename: str, + x_tolerance: Union[int, float] = 1.5, + y_tolerance: Union[int, float] = 2, + keep_blank_chars: bool = False, + use_text_flow: bool = False, + horizontal_ltr: bool = True, # Should words be read left-to-right? + vertical_ttb: bool = True, # Should vertical words be read top-to-bottom? + extra_attrs: Optional[List[str]] = None, + split_at_punctuation: Union[bool, str] = False, + dpi: int = 72, +) -> Tuple[List[List[TextRegion]], List[Image.Image]]: + """Loads the image and word objects from a pdf using pdfplumber and the image renderings of the + pdf pages using pdf2image""" + pdf_object = pdfplumber.open(filename) + layouts = [] + images = [] + for page in pdf_object.pages: + plumber_words = page.extract_words( + x_tolerance=x_tolerance, + y_tolerance=y_tolerance, + keep_blank_chars=keep_blank_chars, + use_text_flow=use_text_flow, + horizontal_ltr=horizontal_ltr, + vertical_ttb=vertical_ttb, + extra_attrs=extra_attrs, + split_at_punctuation=split_at_punctuation, + ) + word_objs = [ + TextRegion( + x1=word["x0"], y1=word["top"], x2=word["x1"], y2=word["bottom"], text=word["text"] + ) + for word in plumber_words + ] + image_objs = [ + ImageTextRegion(x1=image["x0"], y1=image["y0"], x2=image["x1"], y2=image["y1"]) + for image in page.images + ] + layout = word_objs + image_objs + layouts.append(layout) + + images = pdf2image.convert_from_path(filename, dpi=dpi) + return layouts, images + + +def needs_ocr( + region: TextRegion, + word_objects: List[TextRegion], + image_objects: List[ImageTextRegion], + ocr_strategy: str, +) -> bool: + """Logic to determine whether ocr is needed to extract text from given region.""" + if ocr_strategy == "force": + return True + elif ocr_strategy == "auto": + # If any image object overlaps with the region of interest, we have hope of getting some + # text from OCR. Otherwise, there's nothing there to find, no need to waste our time with + # OCR. + image_intersects = any(region.intersects(img_obj) for img_obj in image_objects) + if region.text is None: + # If the region has no text check if any images overlap with the region that might + # contain text. + if any(obj.is_in(region) and obj.text is not None for obj in word_objects): + # If there are word objects in the region, we defer to that rather than OCR + return False + else: + return image_intersects + elif cid_ratio(region.text) > 0.5: + # If the region has text, we should only have to OCR if too much of the text is + # uninterpretable. + return image_intersects + else: + return False + else: + return False diff --git a/unstructured_inference/models/detectron2.py b/unstructured_inference/models/detectron2.py index 27b3a7da..19f82e66 100644 --- a/unstructured_inference/models/detectron2.py +++ b/unstructured_inference/models/detectron2.py @@ -10,6 +10,7 @@ from huggingface_hub import hf_hub_download from unstructured_inference.logger import logger +from unstructured_inference.inference.elements import LayoutElement from unstructured_inference.models.unstructuredmodel import UnstructuredModel from unstructured_inference.utils import LazyDict, LazyEvaluateInfo @@ -59,8 +60,10 @@ class UnstructuredDetectronModel(UnstructuredModel): """Unstructured model wrapper for Detectron2LayoutModel.""" def predict(self, x: Image): + """Makes a prediction using detectron2 model.""" super().predict(x) - return self.model.detect(x) + prediction = self.model.detect(x) + return [LayoutElement.from_lp_textblock(block) for block in prediction] def initialize( self, diff --git a/unstructured_inference/models/donut.py b/unstructured_inference/models/donut.py index 342a781c..7ff5134d 100644 --- a/unstructured_inference/models/donut.py +++ b/unstructured_inference/models/donut.py @@ -12,6 +12,7 @@ class UnstructuredDonutModel(UnstructuredModel): """Unstructured model wrapper for Donut image transformer.""" def predict(self, x: Image): + """Make prediction using donut model""" super().predict(x) return self.run_prediction(x) @@ -45,6 +46,7 @@ def initialize( self.model.to(device) def run_prediction(self, x: Image): + """Internal prediction method.""" pixel_values = self.processor(x, return_tensors="pt").pixel_values decoder_input_ids = self.processor.tokenizer( self.task_prompt, add_special_tokens=False, return_tensors="pt" diff --git a/unstructured_inference/models/unstructuredmodel.py b/unstructured_inference/models/unstructuredmodel.py index 388e50b8..1082f6cc 100644 --- a/unstructured_inference/models/unstructuredmodel.py +++ b/unstructured_inference/models/unstructuredmodel.py @@ -23,6 +23,7 @@ def predict(self, x: Any) -> Any: pass # pragma: no cover def __call__(self, x: Any): + """Inference using function call interface.""" return self.predict(x) @abstractmethod diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py index e2a09d5a..cf50d16c 100644 --- a/unstructured_inference/models/yolox.py +++ b/unstructured_inference/models/yolox.py @@ -6,11 +6,11 @@ from PIL import Image import cv2 from huggingface_hub import hf_hub_download -from layoutparser.elements.layout_elements import TextBlock, Rectangle -from layoutparser.elements.layout import Layout import numpy as np import onnxruntime +from typing import List +from unstructured_inference.inference.elements import LayoutElement from unstructured_inference.models.unstructuredmodel import UnstructuredModel from unstructured_inference.visualize import draw_bounding_boxes from unstructured_inference.utils import LazyDict, LazyEvaluateInfo @@ -47,17 +47,19 @@ class UnstructuredYoloXModel(UnstructuredModel): def predict(self, x: Image): + """Predict using YoloX model.""" super().predict(x) return self.image_processing(x) def initialize(self, model_path: str, label_map: dict): + """Start inference session for YoloX model.""" self.model = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"]) self.layout_classes = label_map def image_processing( self, image: Image = None, - ) -> Layout: + ) -> List[LayoutElement]: """Method runing YoloX for layout detection, returns a PageLayout parameters ---------- @@ -94,24 +96,25 @@ def image_processing( boxes_xyxy /= ratio dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) - blocks = [] + regions = [] for det in dets: # Each detection should have (x1,y1,x2,y2,probability,class) format # being (x1,y1) the top left and (x2,y2) the bottom right x1, y1, x2, y2, _, class_id = det.tolist() detected_class = self.layout_classes[int(class_id)] - block = TextBlock(type=detected_class, text=None, block=Rectangle(x1, y1, x2, y2)) + region = LayoutElement(x1, y1, x2, y2, text=None, type=detected_class) - blocks.append(block) + regions.append(region) - blocks.sort(key=lambda element: element.coordinates[1]) + regions.sort(key=lambda element: element.y1) - page_layout = Layout(blocks=blocks) # TODO(benjamin): encode image as base64? + page_layout = regions # TODO(benjamin): encode image as base64? return page_layout def annotate_image(self, image_fn, dets, out_fn): + """Draw bounding boxes and prediction metadata.""" origin_img = np.array(Image.open(image_fn)) final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] @@ -130,6 +133,7 @@ def annotate_image(self, image_fn, dets, out_fn): def preprocess(img, input_size, swap=(2, 0, 1)): + """Preprocess image data before YoloX inference.""" if len(img.shape) == 3: padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 else: @@ -149,6 +153,7 @@ def preprocess(img, input_size, swap=(2, 0, 1)): def demo_postprocess(outputs, img_size, p6=False): + """Postprocessing for YoloX model.""" grids = [] expanded_strides = []