diff --git a/CHANGELOG.md b/CHANGELOG.md index b22bcbd1..a7631e08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ -## 0.3.3-dev2 +## 0.4.0 +* Added logic to partition granular elements (words, characters) by proximity +* Text extraction is now delegated to text regions rather than being handled centrally * Fixed embedded image coordinates being interpreted differently than embedded text coordinates * Update to how dependencies are being handled * Update detectron2 version diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index ba9538f3..d13641f3 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -8,6 +8,7 @@ from PIL import Image import unstructured_inference.inference.layout as layout +import unstructured_inference.inference.elements as elements import unstructured_inference.models.base as models import unstructured_inference.models.detectron2 as detectron2 import unstructured_inference.models.tesseract as tesseract @@ -20,9 +21,9 @@ def mock_image(): @pytest.fixture def mock_page_layout(): - text_block = layout.TextRegion(2, 4, 6, 8, text="A very repetitive narrative. " * 10) + text_block = layout.EmbeddedTextRegion(2, 4, 6, 8, text="A very repetitive narrative. " * 10) - title_block = layout.TextRegion(1, 2, 3, 4, text="A Catchy Title") + title_block = layout.EmbeddedTextRegion(1, 2, 3, 4, text="A Catchy Title") return [text_block, title_block] @@ -49,7 +50,7 @@ def detect(self, *args): image = Image.fromarray(np.random.randint(12, 24, (40, 40)), mode="RGB") text_block = layout.TextRegion(1, 2, 3, 4, text=None) - assert layout.ocr(text_block, image=image) == mock_text + assert elements.ocr(text_block, image=image) == mock_text class MockLayoutModel: @@ -69,12 +70,12 @@ def test_get_page_elements(monkeypatch, mock_page_layout): number=0, image=image, layout=mock_page_layout, model=MockLayoutModel(mock_page_layout) ) - elements = page.get_elements(inplace=False) + elements = page.get_elements_with_model(inplace=False) assert str(elements[0]) == "A Catchy Title" assert str(elements[1]).startswith("A very repetitive narrative.") - page.get_elements(inplace=True) + page.get_elements_with_model(inplace=True) assert elements == page.elements @@ -95,13 +96,13 @@ def test_get_page_elements_with_ocr(monkeypatch): doc_layout = [text_block, image_block] monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) - monkeypatch.setattr(layout, "ocr", lambda *args: "An Even Catchier Title") + monkeypatch.setattr(elements, "ocr", lambda *args: "An Even Catchier Title") image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB") page = layout.PageLayout( number=0, image=image, layout=doc_layout, model=MockLayoutModel(doc_layout) ) - page.get_elements() + page.get_elements_with_model() assert str(page) == "\n\nAn Even Catchier Title" @@ -174,7 +175,7 @@ def tolist(self): return [1, 2, 3, 4] -class MockTextRegion(layout.TextRegion): +class MockEmbeddedTextRegion(layout.EmbeddedTextRegion): def __init__(self, type=None, text=None, ocr_text=None): self.type = type self.text = text @@ -193,7 +194,7 @@ def __init__(self, layout=None, model=None, ocr_strategy="auto", extract_tables= self.ocr_strategy = ocr_strategy self.extract_tables = extract_tables - def ocr(self, text_block: MockTextRegion): + def ocr(self, text_block: MockEmbeddedTextRegion): return text_block.ocr_text @@ -202,7 +203,7 @@ def ocr(self, 
text_block: MockTextRegion): [("base", 0.0), ("", 0.0), ("(cid:2)", 1.0), ("(cid:1)a", 0.5), ("c(cid:1)ab", 0.25)], ) def test_cid_ratio(text, expected): - assert layout.cid_ratio(text) == expected + assert elements.cid_ratio(text) == expected @pytest.mark.parametrize( @@ -210,7 +211,7 @@ def test_cid_ratio(text, expected): [("base", False), ("(cid:2)", True), ("(cid:1234567890)", True), ("jkl;(cid:12)asdf", True)], ) def test_is_cid_present(text, expected): - assert layout.is_cid_present(text) == expected + assert elements.is_cid_present(text) == expected class MockLayout: @@ -241,7 +242,7 @@ def filter_by(self, *args, **kwargs): ], ) def test_get_element_from_block(block_text, layout_texts, mock_image, expected_text): - with patch("unstructured_inference.inference.layout.ocr", return_value="ocr"): + with patch("unstructured_inference.inference.elements.ocr", return_value="ocr"): block = layout.TextRegion(0, 0, 10, 10, text=block_text) captured_layout = [ layout.TextRegion(i + 1, i + 1, i + 2, i + 2, text=text) @@ -263,7 +264,7 @@ def test_from_image_file(monkeypatch, mock_page_layout, filetype): def mock_get_elements(self, *args, **kwargs): self.elements = [mock_page_layout] - monkeypatch.setattr(layout.PageLayout, "get_elements", mock_get_elements) + monkeypatch.setattr(layout.PageLayout, "get_elements_with_model", mock_get_elements) elements = ( layout.DocumentLayout.from_image_file(f"sample-docs/loremipsum.{filetype}") .pages[0] @@ -301,12 +302,12 @@ def test_get_elements_from_layout(mock_page_layout, idx): @pytest.mark.parametrize( "fixed_layouts, called_method, not_called_method", [ - ([MockLayout()], "get_elements_from_layout", "get_elements"), - (None, "get_elements", "get_elements_from_layout"), + ([MockLayout()], "get_elements_from_layout", "get_elements_with_model"), + (None, "get_elements_with_model", "get_elements_from_layout"), ], ) def test_from_file_fixed_layout(fixed_layouts, called_method, not_called_method): - with patch.object(layout.PageLayout, "get_elements", return_value=[]), patch.object( + with patch.object(layout.PageLayout, "get_elements_with_model", return_value=[]), patch.object( layout.PageLayout, "get_elements_from_layout", return_value=[] ): layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf", fixed_layouts=fixed_layouts) @@ -323,35 +324,46 @@ def test_invalid_ocr_strategy_raises(mock_image): ("text", "expected"), [("a\ts\x0cd\nfas\fd\rf\b", "asdfasdf"), ("\"'\\", "\"'\\")] ) def test_remove_control_characters(text, expected): - assert layout.remove_control_characters(text) == expected + assert elements.remove_control_characters(text) == expected -no_text_region = layout.TextRegion(0, 0, 100, 100) -text_region = layout.TextRegion(0, 0, 100, 100, text="test") -cid_text_region = layout.TextRegion(0, 0, 100, 100, text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)") +no_text_region = layout.EmbeddedTextRegion(0, 0, 100, 100) +text_region = layout.EmbeddedTextRegion(0, 0, 100, 100, text="test") +cid_text_region = layout.EmbeddedTextRegion( + 0, 0, 100, 100, text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)" +) overlapping_rect = layout.ImageTextRegion(50, 50, 150, 150) nonoverlapping_rect = layout.ImageTextRegion(150, 150, 200, 200) -populated_text_region = layout.TextRegion(50, 50, 60, 60, text="test") -unpopulated_text_region = layout.TextRegion(50, 50, 60, 60, text=None) +populated_text_region = layout.EmbeddedTextRegion(50, 50, 60, 60, text="test") +unpopulated_text_region = layout.EmbeddedTextRegion(50, 50, 60, 60, text=None) @pytest.mark.parametrize( - ("region", 
"text_objects", "image_objects", "ocr_strategy", "expected"), + ("region", "objects", "ocr_strategy", "expected"), [ - (no_text_region, [], [nonoverlapping_rect], "auto", False), - (no_text_region, [], [overlapping_rect], "auto", True), - (no_text_region, [], [], "auto", False), - (no_text_region, [populated_text_region], [nonoverlapping_rect], "auto", False), - (no_text_region, [populated_text_region], [overlapping_rect], "auto", False), - (no_text_region, [populated_text_region], [], "auto", False), - (no_text_region, [unpopulated_text_region], [nonoverlapping_rect], "auto", False), - (no_text_region, [unpopulated_text_region], [overlapping_rect], "auto", True), - (no_text_region, [unpopulated_text_region], [], "auto", False), + (no_text_region, [nonoverlapping_rect], "auto", False), + (no_text_region, [overlapping_rect], "auto", True), + (no_text_region, [], "auto", False), + (no_text_region, [populated_text_region, nonoverlapping_rect], "auto", False), + (no_text_region, [populated_text_region, overlapping_rect], "auto", False), + (no_text_region, [populated_text_region], "auto", False), + (no_text_region, [unpopulated_text_region, nonoverlapping_rect], "auto", False), + (no_text_region, [unpopulated_text_region, overlapping_rect], "auto", True), + (no_text_region, [unpopulated_text_region], "auto", False), *list( product( [text_region], - [[], [populated_text_region], [unpopulated_text_region]], - [[], [nonoverlapping_rect], [overlapping_rect]], + [ + [], + [populated_text_region], + [unpopulated_text_region], + [nonoverlapping_rect], + [overlapping_rect], + [populated_text_region, nonoverlapping_rect], + [populated_text_region, overlapping_rect], + [unpopulated_text_region, nonoverlapping_rect], + [unpopulated_text_region, overlapping_rect], + ], ["auto"], [False], ) @@ -359,8 +371,14 @@ def test_remove_control_characters(text, expected): *list( product( [cid_text_region], - [[], [populated_text_region], [unpopulated_text_region]], - [[overlapping_rect]], + [ + [], + [populated_text_region], + [unpopulated_text_region], + [overlapping_rect], + [populated_text_region, overlapping_rect], + [unpopulated_text_region, overlapping_rect], + ], ["auto"], [True], ) @@ -368,8 +386,17 @@ def test_remove_control_characters(text, expected): *list( product( [no_text_region, text_region, cid_text_region], - [[], [populated_text_region], [unpopulated_text_region]], - [[], [nonoverlapping_rect], [overlapping_rect]], + [ + [], + [populated_text_region], + [unpopulated_text_region], + [nonoverlapping_rect], + [overlapping_rect], + [populated_text_region, nonoverlapping_rect], + [populated_text_region, overlapping_rect], + [unpopulated_text_region, nonoverlapping_rect], + [unpopulated_text_region, overlapping_rect], + ], ["force"], [True], ) @@ -377,16 +404,25 @@ def test_remove_control_characters(text, expected): *list( product( [no_text_region, text_region, cid_text_region], - [[], [populated_text_region], [unpopulated_text_region]], - [[], [nonoverlapping_rect], [overlapping_rect]], + [ + [], + [populated_text_region], + [unpopulated_text_region], + [nonoverlapping_rect], + [overlapping_rect], + [populated_text_region, nonoverlapping_rect], + [populated_text_region, overlapping_rect], + [unpopulated_text_region, nonoverlapping_rect], + [unpopulated_text_region, overlapping_rect], + ], ["never"], [False], ) ), ], ) -def test_ocr_image(region, text_objects, image_objects, ocr_strategy, expected): - assert layout.needs_ocr(region, text_objects, image_objects, ocr_strategy) is expected +def 
test_ocr_image(region, objects, ocr_strategy, expected): + assert elements.needs_ocr(region, objects, ocr_strategy) is expected def test_load_pdf(): diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py new file mode 100644 index 00000000..a86abd16 --- /dev/null +++ b/test_unstructured_inference/test_elements.py @@ -0,0 +1,84 @@ +from random import randint +from unstructured_inference.inference import elements +from unstructured_inference.inference.layout import load_pdf + + +def intersect_brute(rect1, rect2): + return any( + (rect2.x1 <= x <= rect2.x2) and (rect2.y1 <= y <= rect2.y2) + for x in range(rect1.x1, rect1.x2 + 1) + for y in range(rect1.y1, rect1.y2 + 1) + ) + + +def rand_rect(size=10): + x1 = randint(0, 30 - size) + y1 = randint(0, 30 - size) + return elements.Rectangle(x1, y1, x1 + size, y1 + size) + + +def test_intersects_overlap(): + for _ in range(1000): + rect1 = rand_rect() + rect2 = rand_rect() + assert intersect_brute(rect1, rect2) == rect1.intersects(rect2) == rect2.intersects(rect1) + + +def test_intersects_subset(): + for _ in range(1000): + rect1 = rand_rect() + rect2 = rand_rect(20) + assert intersect_brute(rect1, rect2) == rect1.intersects(rect2) == rect2.intersects(rect1) + + +def test_intersection_of_lots_of_rects(): + for _ in range(1000): + n_rects = 10 + rects = [rand_rect(6) for _ in range(n_rects)] + intersection_mtx = elements.intersections(*rects) + for i in range(n_rects): + for j in range(n_rects): + assert ( + intersect_brute(rects[i], rects[j]) + == intersection_mtx[i, j] + == intersection_mtx[j, i] + ) + + +def test_rectangle_width_height(): + for _ in range(1000): + x1 = randint(0, 50) + x2 = randint(x1 + 1, 100) + y1 = randint(0, 50) + y2 = randint(y1 + 1, 100) + rect = elements.Rectangle(x1, y1, x2, y2) + assert rect.width == x2 - x1 + assert rect.height == y2 - y1 + + +def test_minimal_containing_rect(): + for _ in range(1000): + rect1 = rand_rect() + rect2 = rand_rect() + big_rect = elements.minimal_containing_region(rect1, rect2) + for decrease_attr in ["x1", "y1", "x2", "y2"]: + almost_as_big_rect = rand_rect() + mod = 1 if decrease_attr.endswith("1") else -1 + for attr in ["x1", "y1", "x2", "y2"]: + if attr == decrease_attr: + setattr(almost_as_big_rect, attr, getattr(big_rect, attr) + mod) + else: + setattr(almost_as_big_rect, attr, getattr(big_rect, attr)) + assert not rect1.is_in(almost_as_big_rect) or not rect2.is_in(almost_as_big_rect) + + assert rect1.is_in(big_rect) + assert rect2.is_in(big_rect) + + +def test_partition_groups_from_regions(): + words, _ = load_pdf("sample-docs/layout-parser-paper.pdf") + groups = elements.partition_groups_from_regions(words[0]) + assert len(groups) == 9 + sorted_groups = sorted(groups, key=lambda group: group[0].y1) + text = "".join([el.text for el in sorted_groups[-1]]) + assert text.startswith("Deep") diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 628a64b5..88e4521e 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.3.3-dev2" # pragma: no cover +__version__ = "0.4.0" # pragma: no cover diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py index 4cc4c592..22914eba 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -1,9 +1,24 @@ from __future__ import annotations from copy import deepcopy from dataclasses import 
dataclass -from typing import Optional, Union +import re +from typing import Optional, Union, Sequence, List +import unicodedata -from layoutparser.elements.layout import TextBlock +import numpy as np +from PIL import Image +from scipy.sparse.csgraph import connected_components + +from unstructured_inference.logger import logger +from unstructured_inference.models import tesseract + +# When extending the boundaries of a PDF object for the purpose of determining which other elements +# should be considered in the same text region, we use a relative distance based on some fraction of +# the block height (typically character height). This is the fraction used for the horizontal +# extension applied to the left and right sides. +H_PADDING_COEF = 0.4 +# Same as above but the vertical extension. +V_PADDING_COEF = 0.3 @dataclass @@ -13,37 +28,47 @@ class Rectangle: x2: Union[int, float] y2: Union[int, float] - def pad(self, padding: int): + def pad(self, padding: Union[int, float]): """Increases (or decreases, if padding is negative) the size of the rectangle by extending the boundary outward (resp. inward).""" + out_object = self.hpad(padding).vpad(padding) + return out_object + + def hpad(self, padding: Union[int, float]): + """Increases (or decreases, if padding is negative) the size of the rectangle by extending + the left and right sides of the boundary outward (resp. inward).""" out_object = deepcopy(self) out_object.x1 -= padding - out_object.y1 -= padding out_object.x2 += padding + return out_object + + def vpad(self, padding: Union[int, float]): + """Increases (or decreases, if padding is negative) the size of the rectangle by extending + the top and bottom of the boundary outward (resp. inward).""" + out_object = deepcopy(self) + out_object.y1 -= padding out_object.y2 += padding return out_object @property - def width(self): + def width(self) -> Union[int, float]: """Width of rectangle""" return self.x2 - self.x1 @property - def height(self): + def height(self) -> Union[int, float]: """Height of rectangle""" return self.y2 - self.y1 - def is_disjoint(self, other: Rectangle): + def is_disjoint(self, other: Rectangle) -> bool: """Checks whether this rectangle is disjoint from another rectangle.""" - return ((self.x2 < other.x1) or (self.x1 > other.x2)) and ( - (self.y2 < other.y1) or (self.y1 > other.y2) - ) + return not self.intersects(other) - def intersects(self, other: Rectangle): + def intersects(self, other: Rectangle) -> bool: """Checks whether this rectangle intersects another rectangle.""" - return not self.is_disjoint(other) + return intersections(self, other)[0, 1] - def is_in(self, other: Rectangle, error_margin: Optional[int] = None): + def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = None) -> bool: """Checks whether this rectangle is contained within another rectangle.""" if error_margin is not None: padded_other = other.pad(error_margin) @@ -64,6 +89,66 @@ def coordinates(self): return ((self.x1, self.y1), (self.x1, self.y2), (self.x2, self.y2), (self.x2, self.y1)) +def minimal_containing_region(*regions: Rectangle) -> Rectangle: + """Returns the smallest rectangular region that contains all regions passed.""" + x1 = min(region.x1 for region in regions) + y1 = min(region.y1 for region in regions) + x2 = max(region.x2 for region in regions) + y2 = max(region.y2 for region in regions) + + # Return the most specialized class of which every region is an instance + def least_common_superclass(*instances): + mros = (type(ins).mro() for ins in
instances) + mro = next(mros) + common = set(mro).intersection(*mros) + return next((x for x in mro if x in common), Rectangle) + + cls = least_common_superclass(*regions) + return cls(x1, y1, x2, y2) + + +def partition_groups_from_regions(regions: Sequence[Rectangle]) -> List[List[Rectangle]]: + """Partitions regions into groups of regions based on proximity. Returns list of lists of + regions, each list corresponding to a group.""" + padded_regions = [ + r.vpad(r.height * V_PADDING_COEF).hpad(r.height * H_PADDING_COEF) for r in regions + ] + + intersection_mtx = intersections(*padded_regions) + + _, group_nums = connected_components(intersection_mtx) + groups: List[List[Rectangle]] = [[] for _ in range(max(group_nums) + 1)] + for region, group_num in zip(regions, group_nums): + groups[group_num].append(region) + + return groups + + +def intersections(*rects: Rectangle): + """Returns a square boolean matrix of intersections of an arbitrary number of rectangles, i.e. + the ijth entry of the matrix is True if and only if the ith Rectangle and jth Rectangle + intersect.""" + coords = np.stack([[[r.x1, r.y1], [r.x2, r.y2]] for r in rects], axis=-1) + + (x1s, y1s), (x2s, y2s) = coords + + # Use broadcasting to get comparison matrices. + # For Rectangles r1 and r2, any of the following conditions makes the rectangles disjoint: + # r1.x1 > r2.x2 + # r1.y1 > r2.y2 + # r2.x1 > r1.x2 + # r2.y1 > r1.y2 + # Then we take the complement (~) of the disjointness matrix to get the intersection matrix. + intersections = ~( + (x1s[None] > x2s[..., None]) + | (y1s[None] > y2s[..., None]) + | (x1s[None] > x2s[..., None]).T + | (y1s[None] > y2s[..., None]).T + ) + + return intersections + + @dataclass class TextRegion(Rectangle): text: Optional[str] = None @@ -71,36 +156,146 @@ class TextRegion(Rectangle): def __str__(self) -> str: return str(self.text) + def extract_text( + self, + objects: Optional[List[TextRegion]], + image: Optional[Image.Image] = None, + extract_tables: bool = False, + ocr_strategy: str = "auto", + ) -> str: + """Extracts text contained in region.""" + if self.text is not None: + # If block text is already populated, we'll assume it's correct + text = self.text + elif objects is not None: + text = aggregate_by_block(self, image, objects, ocr_strategy) + elif image is not None: + if ocr_strategy != "never": + # We don't have anything to go on but the image itself, so we use OCR + text = ocr(self, image) + else: + text = "" + else: + raise ValueError( + "Got arguments image and objects as None, at least one must be populated to use for " + "text extraction."
+ ) + return text + + +class EmbeddedTextRegion(TextRegion): + def extract_text( + self, + objects: Optional[List[TextRegion]], + image: Optional[Image.Image] = None, + extract_tables: bool = False, + ocr_strategy: str = "auto", + ) -> str: + """Extracts text contained in region.""" + if self.text is None: + return "" + else: + return self.text + class ImageTextRegion(TextRegion): - pass + def extract_text( + self, + objects: Optional[List[TextRegion]], + image: Optional[Image.Image] = None, + extract_tables: bool = False, + ocr_strategy: str = "auto", + ) -> str: + """Extracts text contained in region.""" + if self.text is None: + if ocr_strategy == "never" or image is None: + return "" + else: + return ocr(self, image) + else: + return super().extract_text(objects, image, extract_tables, ocr_strategy) -@dataclass -class LayoutElement(TextRegion): - type: Optional[str] = None - - def to_dict(self) -> dict: - """Converts the class instance to dictionary form.""" - out_dict = { - "coordinates": self.coordinates, - "text": self.text, - "type": self.type, - } - return out_dict - - @classmethod - def from_region(cls, region: Rectangle): - """Create LayoutElement from superclass.""" - x1, y1, x2, y2 = region.x1, region.y1, region.x2, region.y2 - text = region.text if hasattr(region, "text") else None - type = region.type if hasattr(region, "type") else None - return cls(x1, y1, x2, y2, text, type) - - @classmethod - def from_lp_textblock(cls, textblock: TextBlock): - """Create LayoutElement from layoutparser TextBlock object.""" - x1, y1, x2, y2 = textblock.coordinates - text = textblock.text - type = textblock.type - return cls(x1, y1, x2, y2, text, type) +def ocr(text_block: TextRegion, image: Image.Image) -> str: + """Runs a cropped text block image through an OCR agent.""" + logger.debug("Running OCR on text block ...") + tesseract.load_agent() + padded_block = text_block.pad(12) + cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) + return tesseract.ocr_agent.detect(cropped_image) + + +def needs_ocr( + region: TextRegion, + pdf_objects: List[TextRegion], + ocr_strategy: str, +) -> bool: + """Logic to determine whether OCR is needed to extract text from a given region.""" + if ocr_strategy == "force": + return True + elif ocr_strategy == "auto": + image_objects = [obj for obj in pdf_objects if isinstance(obj, ImageTextRegion)] + word_objects = [obj for obj in pdf_objects if isinstance(obj, EmbeddedTextRegion)] + # If any image object overlaps with the region of interest, we have hope of getting some + # text from OCR. Otherwise, there's nothing there to find, no need to waste our time with + # OCR. + image_intersects = any(region.intersects(img_obj) for img_obj in image_objects) + if region.text is None: + # If the region has no text, check if any images overlap with the region that might + # contain text. + if any(obj.is_in(region) and obj.text is not None for obj in word_objects): + # If there are word objects in the region, we defer to that rather than OCR + return False + else: + return image_intersects + elif cid_ratio(region.text) > 0.5: + # If the region has text, we should only have to OCR if too much of the text is + # uninterpretable.
+ return True + else: + return False + else: + return False + + +def aggregate_by_block( + text_region: TextRegion, + image: Optional[Image.Image], + pdf_objects: List[TextRegion], + ocr_strategy: str = "auto", +) -> str: + """Extracts the text aggregated from the elements of the given layout that lie within the given + block.""" + if image is not None and needs_ocr(text_region, pdf_objects, ocr_strategy): + text = ocr(text_region, image) + else: + filtered_blocks = [obj for obj in pdf_objects if obj.is_in(text_region, error_margin=5)] + for little_block in filtered_blocks: + if image is not None and needs_ocr(little_block, pdf_objects, ocr_strategy): + little_block.text = ocr(little_block, image) + text = " ".join([x.text for x in filtered_blocks if x.text]) + text = remove_control_characters(text) + return text + + +def cid_ratio(text: str) -> float: + """Gets ratio of unknown 'cid' characters extracted from text to all characters.""" + if not is_cid_present(text): + return 0.0 + cid_pattern = r"\(cid\:(\d+)\)" + unmatched, n_cid = re.subn(cid_pattern, "", text) + total = n_cid + len(unmatched) + return n_cid / total + + +def is_cid_present(text: str) -> bool: + """Checks if a cid code is present in a text selection.""" + if len(text) < len("(cid:x)"): + return False + return text.find("(cid:") != -1 + + +def remove_control_characters(text: str) -> str: + """Removes control characters from text.""" + out_text = "".join(c for c in text if unicodedata.category(c)[0] != "C") + return out_text diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index 71f8f832..d08ccadc 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -1,19 +1,20 @@ from __future__ import annotations import os -import re import tempfile from typing import List, Optional, Tuple, Union, BinaryIO -import unicodedata import numpy as np import pdfplumber import pdf2image from PIL import Image -from unstructured_inference.inference.elements import TextRegion, ImageTextRegion, LayoutElement +from unstructured_inference.inference.elements import ( + TextRegion, + EmbeddedTextRegion, + ImageTextRegion, +) +from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.logger import logger -import unstructured_inference.models.tesseract as tesseract -import unstructured_inference.models.tables as tables from unstructured_inference.models.base import get_model from unstructured_inference.models.unstructuredmodel import UnstructuredModel @@ -135,7 +136,7 @@ def __init__( def __str__(self) -> str: return "\n\n".join([str(element) for element in self.elements]) - def get_elements(self, inplace=True) -> Optional[List[LayoutElement]]: + def get_elements_with_model(self, inplace=True) -> Optional[List[LayoutElement]]: """Uses specified model to detect the elements on the page.""" logger.info("Detecting page elements ...") if self.model is None: @@ -190,7 +191,7 @@ def from_image( extract_tables=extract_tables, ) if fixed_layout is None: - page.get_elements() + page.get_elements_with_model() else: page.elements = page.get_elements_from_layout(fixed_layout) return page @@ -247,101 +248,22 @@ def process_file_with_model( return layout -def cid_ratio(text: str) -> float: - """Gets ratio of unknown 'cid' characters extracted from text to all characters.""" - if not is_cid_present(text): - return 0.0 - cid_pattern = r"\(cid\:(\d+)\)" - unmatched, n_cid = re.subn(cid_pattern, "", text) - total = n_cid + 
len(unmatched) - return n_cid / total - - -def is_cid_present(text: str) -> bool: - """Checks if a cid code is present in a text selection.""" - if len(text) < len("(cid:x)"): - return False - return text.find("(cid:") != -1 - - def get_element_from_block( block: TextRegion, image: Optional[Image.Image] = None, - pdf_objects: Optional[List[Union[TextRegion, ImageTextRegion]]] = None, + pdf_objects: Optional[List[TextRegion]] = None, ocr_strategy: str = "auto", extract_tables: bool = False, ) -> LayoutElement: """Creates a LayoutElement from a given layout or image by finding all the text that lies within a given block.""" - if block.text is not None: - # If block text is already populated, we'll assume it's correct - text = block.text - elif extract_tables and isinstance(block, LayoutElement) and block.type == "Table": - text = interprete_table_block(block, image) - elif pdf_objects is not None: - text = aggregate_by_block(block, image, pdf_objects, ocr_strategy) - elif image is not None: - # We don't have anything to go on but the image itself, so we use OCR - text = ocr(block, image) - else: - raise ValueError( - "Got arguments image and layout as None, at least one must be populated to use for " - "text extraction." - ) element = LayoutElement.from_region(block) - element.text = text + element.text = block.extract_text( + objects=pdf_objects, image=image, extract_tables=extract_tables, ocr_strategy=ocr_strategy + ) return element -def aggregate_by_block( - text_region: TextRegion, - image: Optional[Image.Image], - pdf_objects: List[Union[TextRegion, ImageTextRegion]], - ocr_strategy: str = "auto", -) -> str: - """Extracts the text aggregated from the elements of the given layout that lie within the given - block.""" - word_objects = [obj for obj in pdf_objects if isinstance(obj, TextRegion)] - image_objects = [obj for obj in pdf_objects if isinstance(obj, ImageTextRegion)] - if image is not None and needs_ocr(text_region, word_objects, image_objects, ocr_strategy): - text = ocr(text_region, image) - else: - filtered_blocks = [obj for obj in pdf_objects if obj.is_in(text_region, error_margin=5)] - for little_block in filtered_blocks: - if image is not None and needs_ocr( - little_block, word_objects, image_objects, ocr_strategy - ): - little_block.text = ocr(little_block, image) - text = " ".join([x.text for x in filtered_blocks if x.text]) - text = remove_control_characters(text) - return text - - -def interprete_table_block(text_block: TextRegion, image: Image.Image) -> str: - """Extract the contents of a table.""" - tables.load_agent() - if tables.tables_agent is None: - raise RuntimeError("Unable to load table extraction agent.") - padded_block = text_block.pad(12) - cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) - return tables.tables_agent.predict(cropped_image) - - -def ocr(text_block: TextRegion, image: Image.Image) -> str: - """Runs a cropped text block image through and OCR agent.""" - logger.debug("Running OCR on text block ...") - tesseract.load_agent() - padded_block = text_block.pad(12) - cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) - return tesseract.ocr_agent.detect(cropped_image) - - -def remove_control_characters(text: str) -> str: - """Removes control characters from text.""" - out_text = "".join(c for c in text if unicodedata.category(c)[0] != "C") - return out_text - - def load_pdf( filename: str, x_tolerance: Union[int, float] = 1.5, @@ -370,8 +292,8 @@ def 
load_pdf( extra_attrs=extra_attrs, split_at_punctuation=split_at_punctuation, ) - word_objs = [ - TextRegion( + word_objs: List[TextRegion] = [ + EmbeddedTextRegion( x1=word["x0"] * dpi / 72, y1=word["top"] * dpi / 72, x2=word["x1"] * dpi / 72, @@ -380,7 +302,7 @@ def load_pdf( ) for word in plumber_words ] - image_objs = [ + image_objs: List[TextRegion] = [ ImageTextRegion( x1=image["x0"] * dpi / 72, y1=image["top"] * dpi / 72, @@ -394,35 +316,3 @@ def load_pdf( images = pdf2image.convert_from_path(filename, dpi=dpi) return layouts, images - - -def needs_ocr( - region: TextRegion, - word_objects: List[TextRegion], - image_objects: List[ImageTextRegion], - ocr_strategy: str, -) -> bool: - """Logic to determine whether ocr is needed to extract text from given region.""" - if ocr_strategy == "force": - return True - elif ocr_strategy == "auto": - # If any image object overlaps with the region of interest, we have hope of getting some - # text from OCR. Otherwise, there's nothing there to find, no need to waste our time with - # OCR. - image_intersects = any(region.intersects(img_obj) for img_obj in image_objects) - if region.text is None: - # If the region has no text check if any images overlap with the region that might - # contain text. - if any(obj.is_in(region) and obj.text is not None for obj in word_objects): - # If there are word objects in the region, we defer to that rather than OCR - return False - else: - return image_intersects - elif cid_ratio(region.text) > 0.5: - # If the region has text, we should only have to OCR if too much of the text is - # uninterpretable. - return image_intersects - else: - return False - else: - return False diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py new file mode 100644 index 00000000..9476be67 --- /dev/null +++ b/unstructured_inference/inference/layoutelement.py @@ -0,0 +1,71 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import List, Optional + +from layoutparser.elements.layout import TextBlock +from PIL import Image + +from unstructured_inference.inference.elements import Rectangle, TextRegion +from unstructured_inference.models import tables + + +@dataclass +class LayoutElement(TextRegion): + type: Optional[str] = None + + def extract_text( + self, + objects: Optional[List[TextRegion]], + image: Optional[Image.Image] = None, + extract_tables: bool = False, + ocr_strategy: str = "auto", + ): + """Extracts text contained in region""" + if self.text is not None: + # If block text is already populated, we'll assume it's correct + text = self.text + elif extract_tables and isinstance(self, LayoutElement) and self.type == "Table": + text = interprete_table_block(self, image) + else: + text = super().extract_text( + objects=objects, + image=image, + extract_tables=extract_tables, + ocr_strategy=ocr_strategy, + ) + return text + + def to_dict(self) -> dict: + """Converts the class instance to dictionary form.""" + out_dict = { + "coordinates": self.coordinates, + "text": self.text, + "type": self.type, + } + return out_dict + + @classmethod + def from_region(cls, region: Rectangle): + """Create LayoutElement from superclass.""" + x1, y1, x2, y2 = region.x1, region.y1, region.x2, region.y2 + text = region.text if hasattr(region, "text") else None + type = region.type if hasattr(region, "type") else None + return cls(x1, y1, x2, y2, text, type) + + @classmethod + def from_lp_textblock(cls, textblock: TextBlock): + """Create LayoutElement from 
layoutparser TextBlock object.""" + x1, y1, x2, y2 = textblock.coordinates + text = textblock.text + type = textblock.type + return cls(x1, y1, x2, y2, text, type) + + +def interprete_table_block(text_block: TextRegion, image: Image.Image) -> str: + """Extract the contents of a table.""" + tables.load_agent() + if tables.tables_agent is None: + raise RuntimeError("Unable to load table extraction agent.") + padded_block = text_block.pad(12) + cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) + return tables.tables_agent.predict(cropped_image) diff --git a/unstructured_inference/models/detectron2.py b/unstructured_inference/models/detectron2.py index 19f82e66..ab0e3887 100644 --- a/unstructured_inference/models/detectron2.py +++ b/unstructured_inference/models/detectron2.py @@ -10,7 +10,7 @@ from huggingface_hub import hf_hub_download from unstructured_inference.logger import logger -from unstructured_inference.inference.elements import LayoutElement +from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.models.unstructuredmodel import UnstructuredModel from unstructured_inference.utils import LazyDict, LazyEvaluateInfo diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py index cf50d16c..16e62503 100644 --- a/unstructured_inference/models/yolox.py +++ b/unstructured_inference/models/yolox.py @@ -10,7 +10,7 @@ import onnxruntime from typing import List -from unstructured_inference.inference.elements import LayoutElement +from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.models.unstructuredmodel import UnstructuredModel from unstructured_inference.visualize import draw_bounding_boxes from unstructured_inference.utils import LazyDict, LazyEvaluateInfo
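# A minimal usage sketch (not part of the diff) of the proximity grouping this
# release introduces in unstructured_inference.inference.elements. The
# coordinates are made-up values: the first two words, once padded by a
# fraction of their character height, overlap and land in one group, while the
# third word sits far below and forms its own group.
from unstructured_inference.inference.elements import (
    EmbeddedTextRegion,
    Rectangle,
    intersections,
    partition_groups_from_regions,
)

words = [
    EmbeddedTextRegion(0, 0, 10, 10, text="Hello"),
    EmbeddedTextRegion(12, 0, 22, 10, text="world"),
    EmbeddedTextRegion(0, 100, 10, 110, text="Footer"),
]
print([[str(word) for word in group] for group in partition_groups_from_regions(words)])
# Expected (group order may vary): [['Hello', 'world'], ['Footer']]

# intersections() computes all pairwise overlaps in one vectorized pass; entry
# [i, j] of the returned boolean matrix is True iff rectangles i and j intersect.
print(intersections(Rectangle(0, 0, 5, 5), Rectangle(3, 3, 8, 8), Rectangle(9, 9, 12, 12)))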
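# A minimal sketch of the delegated text extraction noted in the changelog:
# each region type now decides how its own text is produced, rather than a
# central helper in layout.py dispatching on type. An EmbeddedTextRegion trusts
# the text parsed out of the PDF and never falls back to OCR. The example
# regions here are assumptions, not fixtures from the repo.
from unstructured_inference.inference.elements import EmbeddedTextRegion

populated = EmbeddedTextRegion(0, 0, 10, 10, text="parsed from the PDF")
empty = EmbeddedTextRegion(0, 0, 10, 10, text=None)
print(populated.extract_text(objects=None))  # -> "parsed from the PDF"
print(empty.extract_text(objects=None))  # -> "" (no OCR fallback for embedded text)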
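# A sketch of the relocated needs_ocr heuristic, again with made-up regions:
# "force" always OCRs, "never" never does, and "auto" OCRs only when an image
# object overlaps a region whose embedded text is missing or mostly (cid:N)
# codes. Note that image and word objects are now passed as a single list and
# separated inside needs_ocr by isinstance checks.
from unstructured_inference.inference.elements import (
    EmbeddedTextRegion,
    ImageTextRegion,
    needs_ocr,
)

region = EmbeddedTextRegion(0, 0, 100, 100, text=None)
overlapping_image = ImageTextRegion(50, 50, 150, 150)
print(needs_ocr(region, [overlapping_image], "auto"))  # -> True
print(needs_ocr(region, [overlapping_image], "never"))  # -> False
print(needs_ocr(region, [], "force"))  # -> True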