diff --git a/CHANGELOG.md b/CHANGELOG.md index 78599002..ef590252 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.2.2-dev0 + +* Add logic to use OCR when layout text is full of unknown characters + ## 0.2.1 * Refactor to facilitate local inference diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index 05c68736..fb7d390e 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -176,3 +176,39 @@ def test_process_file_with_model(monkeypatch, mock_page_layout, model_name): def test_process_file_with_model_raises_on_invalid_model_name(): with pytest.raises(models.UnknownModelException): layout.process_file_with_model("", model_name="fake") + + +class MockPageLayout(layout.PageLayout): + def __init__(self, ocr_text): + self.ocr_text = ocr_text + + def ocr(self, text_block): + return self.ocr_text + + +class MockTextBlock(lp.TextBlock): + def __init__(self, text): + self.text = text + + +def test_interpret_text_block_use_ocr_when_text_symbols_cid(): + fake_text = "(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)" + fake_ocr = "ocrme" + fake_text_block = MockTextBlock(fake_text) + assert MockPageLayout(fake_ocr).interpret_text_block(fake_text_block) == fake_ocr + + +@pytest.mark.parametrize( + "text, expected", + [("base", 0.0), ("", 0.0), ("(cid:2)", 1.0), ("(cid:1)a", 0.5), ("c(cid:1)ab", 0.25)], +) +def test_cid_ratio(text, expected): + assert layout.cid_ratio(text) == expected + + +@pytest.mark.parametrize( + "text, expected", + [("base", False), ("(cid:2)", True), ("(cid:1234567890)", True), ("jkl;(cid:12)asdf", True)], +) +def test_is_cid_present(text, expected): + assert layout.is_cid_present(text) == expected diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 9aa97038..d8249901 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.2.1" # pragma: no cover +__version__ = "0.2.2-dev0" # pragma: no cover diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index f3e37c85..675eb48e 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -1,5 +1,6 @@ from __future__ import annotations from dataclasses import dataclass +import re import tempfile from typing import List, Optional, Tuple, Union, BinaryIO @@ -109,10 +110,7 @@ def get_elements(self, inplace=True) -> Optional[List[LayoutElement]]: text_blocks = self.layout.filter_by(item, center=True) text = str() for text_block in text_blocks: - # NOTE(robinson) - If the text attribute is None, that means the PDF isn't - # already OCR'd and we have to send the snippet out for OCRing. - if text_block.text is None: - text_block.text = self.ocr(text_block) + text_block.text = self.interpret_text_block(text_block) text = " ".join([x for x in text_blocks.get_texts() if x]) elements.append( @@ -124,6 +122,16 @@ def get_elements(self, inplace=True) -> Optional[List[LayoutElement]]: return None return elements + def interpret_text_block(self, text_block: lp.TextBlock) -> str: + """Interprets the text in a TextBlock.""" + # NOTE(robinson) - If the text attribute is None, that means the PDF isn't + # already OCR'd and we have to send the snippet out for OCRing. + if (text_block.text is None) or cid_ratio(text_block.text) > 0.5: + out_text = self.ocr(text_block) + else: + out_text = text_block.text + return out_text + def ocr(self, text_block: lp.TextBlock) -> str: """Runs a cropped text block image through and OCR agent.""" logger.debug("Running OCR on text block ...") @@ -156,3 +164,19 @@ def process_file_with_model(filename: str, model_name: str) -> DocumentLayout: model = None if model_name is None else get_model(model_name) layout = DocumentLayout.from_file(filename, model=model) return layout + + +def cid_ratio(text: str) -> float: + """Gets ratio of unknown 'cid' characters extracted from text to all characters.""" + if not is_cid_present(text): + return 0.0 + cid_pattern = r"\(cid\:(\d+)\)" + unmatched, n_cid = re.subn(cid_pattern, "", text) + total = n_cid + len(unmatched) + return n_cid / total + + +def is_cid_present(text: str) -> bool: + if len(text) < len("(cid:x)"): + return False + return text.find("(cid:") != -1