Skip to content
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.2.2-dev0

* Add logic to use OCR when layout text is full of unknown characters

## 0.2.1

* Refactor to facilitate local inference
Expand Down
36 changes: 36 additions & 0 deletions test_unstructured_inference/inference/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,39 @@ def test_process_file_with_model(monkeypatch, mock_page_layout, model_name):
def test_process_file_with_model_raises_on_invalid_model_name():
with pytest.raises(models.UnknownModelException):
layout.process_file_with_model("", model_name="fake")


class MockPageLayout(layout.PageLayout):
def __init__(self, ocr_text):
self.ocr_text = ocr_text

def ocr(self, text_block):
return self.ocr_text


class MockTextBlock(lp.TextBlock):
def __init__(self, text):
self.text = text


def test_interpret_text_block_use_ocr_when_text_symbols_cid():
fake_text = "(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)"
fake_ocr = "ocrme"
fake_text_block = MockTextBlock(fake_text)
assert MockPageLayout(fake_ocr).interpret_text_block(fake_text_block) == fake_ocr


@pytest.mark.parametrize(
"text, expected",
[("base", 0.0), ("", 0.0), ("(cid:2)", 1.0), ("(cid:1)a", 0.5), ("c(cid:1)ab", 0.25)],
)
def test_cid_ratio(text, expected):
assert layout.cid_ratio(text) == expected


@pytest.mark.parametrize(
"text, expected",
[("base", False), ("(cid:2)", True), ("(cid:1234567890)", True), ("jkl;(cid:12)asdf", True)],
)
def test_is_cid_present(text, expected):
assert layout.is_cid_present(text) == expected
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.1" # pragma: no cover
__version__ = "0.2.2-dev0" # pragma: no cover
32 changes: 28 additions & 4 deletions unstructured_inference/inference/layout.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
from dataclasses import dataclass
import re
import tempfile
from typing import List, Optional, Tuple, Union, BinaryIO

Expand Down Expand Up @@ -109,10 +110,7 @@ def get_elements(self, inplace=True) -> Optional[List[LayoutElement]]:
text_blocks = self.layout.filter_by(item, center=True)
text = str()
for text_block in text_blocks:
# NOTE(robinson) - If the text attribute is None, that means the PDF isn't
# already OCR'd and we have to send the snippet out for OCRing.
if text_block.text is None:
text_block.text = self.ocr(text_block)
text_block.text = self.interpret_text_block(text_block)
text = " ".join([x for x in text_blocks.get_texts() if x])

elements.append(
Expand All @@ -124,6 +122,16 @@ def get_elements(self, inplace=True) -> Optional[List[LayoutElement]]:
return None
return elements

def interpret_text_block(self, text_block: lp.TextBlock) -> str:
"""Interprets the text in a TextBlock."""
# NOTE(robinson) - If the text attribute is None, that means the PDF isn't
# already OCR'd and we have to send the snippet out for OCRing.
if (text_block.text is None) or cid_ratio(text_block.text) > 0.5:
out_text = self.ocr(text_block)
else:
out_text = text_block.text
return out_text

def ocr(self, text_block: lp.TextBlock) -> str:
"""Runs a cropped text block image through and OCR agent."""
logger.debug("Running OCR on text block ...")
Expand Down Expand Up @@ -156,3 +164,19 @@ def process_file_with_model(filename: str, model_name: str) -> DocumentLayout:
model = None if model_name is None else get_model(model_name)
layout = DocumentLayout.from_file(filename, model=model)
return layout


def cid_ratio(text: str) -> float:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know if the cid pattern for unknown characters is specific to pdfminer? Or is that a universal convention?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""Gets ratio of unknown 'cid' characters extracted from text to all characters."""
if not is_cid_present(text):
return 0.0
cid_pattern = r"\(cid\:(\d+)\)"
unmatched, n_cid = re.subn(cid_pattern, "", text)
total = n_cid + len(unmatched)
return n_cid / total


def is_cid_present(text: str) -> bool:
if len(text) < len("(cid:x)"):
return False
return text.find("(cid:") != -1