Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## 0.3.3-dev2
## 0.4.0

* Added logic to partition granular elements (words, characters) by proximity
* Text extraction is now delegated to text regions rather than being handled centrally
* Fixed embedded image coordinates being interpreted differently than embedded text coordinates
* Update to how dependencies are being handled
* Update detectron2 version
Expand Down
120 changes: 78 additions & 42 deletions test_unstructured_inference/inference/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from PIL import Image

import unstructured_inference.inference.layout as layout
import unstructured_inference.inference.elements as elements
import unstructured_inference.models.base as models
import unstructured_inference.models.detectron2 as detectron2
import unstructured_inference.models.tesseract as tesseract
Expand All @@ -20,9 +21,9 @@ def mock_image():

@pytest.fixture
def mock_page_layout():
text_block = layout.TextRegion(2, 4, 6, 8, text="A very repetitive narrative. " * 10)
text_block = layout.EmbeddedTextRegion(2, 4, 6, 8, text="A very repetitive narrative. " * 10)

title_block = layout.TextRegion(1, 2, 3, 4, text="A Catchy Title")
title_block = layout.EmbeddedTextRegion(1, 2, 3, 4, text="A Catchy Title")

return [text_block, title_block]

Expand All @@ -49,7 +50,7 @@ def detect(self, *args):
image = Image.fromarray(np.random.randint(12, 24, (40, 40)), mode="RGB")
text_block = layout.TextRegion(1, 2, 3, 4, text=None)

assert layout.ocr(text_block, image=image) == mock_text
assert elements.ocr(text_block, image=image) == mock_text


class MockLayoutModel:
Expand All @@ -69,12 +70,12 @@ def test_get_page_elements(monkeypatch, mock_page_layout):
number=0, image=image, layout=mock_page_layout, model=MockLayoutModel(mock_page_layout)
)

elements = page.get_elements(inplace=False)
elements = page.get_elements_with_model(inplace=False)

assert str(elements[0]) == "A Catchy Title"
assert str(elements[1]).startswith("A very repetitive narrative.")

page.get_elements(inplace=True)
page.get_elements_with_model(inplace=True)
assert elements == page.elements


Expand All @@ -95,13 +96,13 @@ def test_get_page_elements_with_ocr(monkeypatch):
doc_layout = [text_block, image_block]

monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)
monkeypatch.setattr(layout, "ocr", lambda *args: "An Even Catchier Title")
monkeypatch.setattr(elements, "ocr", lambda *args: "An Even Catchier Title")

image = Image.fromarray(np.random.randint(12, 14, size=(40, 10, 3)), mode="RGB")
page = layout.PageLayout(
number=0, image=image, layout=doc_layout, model=MockLayoutModel(doc_layout)
)
page.get_elements()
page.get_elements_with_model()

assert str(page) == "\n\nAn Even Catchier Title"

Expand Down Expand Up @@ -174,7 +175,7 @@ def tolist(self):
return [1, 2, 3, 4]


class MockTextRegion(layout.TextRegion):
class MockEmbeddedTextRegion(layout.EmbeddedTextRegion):
def __init__(self, type=None, text=None, ocr_text=None):
self.type = type
self.text = text
Expand All @@ -193,7 +194,7 @@ def __init__(self, layout=None, model=None, ocr_strategy="auto", extract_tables=
self.ocr_strategy = ocr_strategy
self.extract_tables = extract_tables

def ocr(self, text_block: MockTextRegion):
def ocr(self, text_block: MockEmbeddedTextRegion):
return text_block.ocr_text


Expand All @@ -202,15 +203,15 @@ def ocr(self, text_block: MockTextRegion):
[("base", 0.0), ("", 0.0), ("(cid:2)", 1.0), ("(cid:1)a", 0.5), ("c(cid:1)ab", 0.25)],
)
def test_cid_ratio(text, expected):
assert layout.cid_ratio(text) == expected
assert elements.cid_ratio(text) == expected


@pytest.mark.parametrize(
"text, expected",
[("base", False), ("(cid:2)", True), ("(cid:1234567890)", True), ("jkl;(cid:12)asdf", True)],
)
def test_is_cid_present(text, expected):
assert layout.is_cid_present(text) == expected
assert elements.is_cid_present(text) == expected


class MockLayout:
Expand Down Expand Up @@ -241,7 +242,7 @@ def filter_by(self, *args, **kwargs):
],
)
def test_get_element_from_block(block_text, layout_texts, mock_image, expected_text):
with patch("unstructured_inference.inference.layout.ocr", return_value="ocr"):
with patch("unstructured_inference.inference.elements.ocr", return_value="ocr"):
block = layout.TextRegion(0, 0, 10, 10, text=block_text)
captured_layout = [
layout.TextRegion(i + 1, i + 1, i + 2, i + 2, text=text)
Expand All @@ -263,7 +264,7 @@ def test_from_image_file(monkeypatch, mock_page_layout, filetype):
def mock_get_elements(self, *args, **kwargs):
self.elements = [mock_page_layout]

monkeypatch.setattr(layout.PageLayout, "get_elements", mock_get_elements)
monkeypatch.setattr(layout.PageLayout, "get_elements_with_model", mock_get_elements)
elements = (
layout.DocumentLayout.from_image_file(f"sample-docs/loremipsum.{filetype}")
.pages[0]
Expand Down Expand Up @@ -301,12 +302,12 @@ def test_get_elements_from_layout(mock_page_layout, idx):
@pytest.mark.parametrize(
"fixed_layouts, called_method, not_called_method",
[
([MockLayout()], "get_elements_from_layout", "get_elements"),
(None, "get_elements", "get_elements_from_layout"),
([MockLayout()], "get_elements_from_layout", "get_elements_with_model"),
(None, "get_elements_with_model", "get_elements_from_layout"),
],
)
def test_from_file_fixed_layout(fixed_layouts, called_method, not_called_method):
with patch.object(layout.PageLayout, "get_elements", return_value=[]), patch.object(
with patch.object(layout.PageLayout, "get_elements_with_model", return_value=[]), patch.object(
layout.PageLayout, "get_elements_from_layout", return_value=[]
):
layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf", fixed_layouts=fixed_layouts)
Expand All @@ -323,70 +324,105 @@ def test_invalid_ocr_strategy_raises(mock_image):
("text", "expected"), [("a\ts\x0cd\nfas\fd\rf\b", "asdfasdf"), ("\"'\\", "\"'\\")]
)
def test_remove_control_characters(text, expected):
assert layout.remove_control_characters(text) == expected
assert elements.remove_control_characters(text) == expected


no_text_region = layout.TextRegion(0, 0, 100, 100)
text_region = layout.TextRegion(0, 0, 100, 100, text="test")
cid_text_region = layout.TextRegion(0, 0, 100, 100, text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)")
no_text_region = layout.EmbeddedTextRegion(0, 0, 100, 100)
text_region = layout.EmbeddedTextRegion(0, 0, 100, 100, text="test")
cid_text_region = layout.EmbeddedTextRegion(
0, 0, 100, 100, text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)"
)
overlapping_rect = layout.ImageTextRegion(50, 50, 150, 150)
nonoverlapping_rect = layout.ImageTextRegion(150, 150, 200, 200)
populated_text_region = layout.TextRegion(50, 50, 60, 60, text="test")
unpopulated_text_region = layout.TextRegion(50, 50, 60, 60, text=None)
populated_text_region = layout.EmbeddedTextRegion(50, 50, 60, 60, text="test")
unpopulated_text_region = layout.EmbeddedTextRegion(50, 50, 60, 60, text=None)


@pytest.mark.parametrize(
("region", "text_objects", "image_objects", "ocr_strategy", "expected"),
("region", "objects", "ocr_strategy", "expected"),
[
(no_text_region, [], [nonoverlapping_rect], "auto", False),
(no_text_region, [], [overlapping_rect], "auto", True),
(no_text_region, [], [], "auto", False),
(no_text_region, [populated_text_region], [nonoverlapping_rect], "auto", False),
(no_text_region, [populated_text_region], [overlapping_rect], "auto", False),
(no_text_region, [populated_text_region], [], "auto", False),
(no_text_region, [unpopulated_text_region], [nonoverlapping_rect], "auto", False),
(no_text_region, [unpopulated_text_region], [overlapping_rect], "auto", True),
(no_text_region, [unpopulated_text_region], [], "auto", False),
(no_text_region, [nonoverlapping_rect], "auto", False),
(no_text_region, [overlapping_rect], "auto", True),
(no_text_region, [], "auto", False),
(no_text_region, [populated_text_region, nonoverlapping_rect], "auto", False),
(no_text_region, [populated_text_region, overlapping_rect], "auto", False),
(no_text_region, [populated_text_region], "auto", False),
(no_text_region, [unpopulated_text_region, nonoverlapping_rect], "auto", False),
(no_text_region, [unpopulated_text_region, overlapping_rect], "auto", True),
(no_text_region, [unpopulated_text_region], "auto", False),
*list(
product(
[text_region],
[[], [populated_text_region], [unpopulated_text_region]],
[[], [nonoverlapping_rect], [overlapping_rect]],
[
[],
[populated_text_region],
[unpopulated_text_region],
[nonoverlapping_rect],
[overlapping_rect],
[populated_text_region, nonoverlapping_rect],
[populated_text_region, overlapping_rect],
[unpopulated_text_region, nonoverlapping_rect],
[unpopulated_text_region, overlapping_rect],
],
["auto"],
[False],
)
),
*list(
product(
[cid_text_region],
[[], [populated_text_region], [unpopulated_text_region]],
[[overlapping_rect]],
[
[],
[populated_text_region],
[unpopulated_text_region],
[overlapping_rect],
[populated_text_region, overlapping_rect],
[unpopulated_text_region, overlapping_rect],
],
["auto"],
[True],
)
),
*list(
product(
[no_text_region, text_region, cid_text_region],
[[], [populated_text_region], [unpopulated_text_region]],
[[], [nonoverlapping_rect], [overlapping_rect]],
[
[],
[populated_text_region],
[unpopulated_text_region],
[nonoverlapping_rect],
[overlapping_rect],
[populated_text_region, nonoverlapping_rect],
[populated_text_region, overlapping_rect],
[unpopulated_text_region, nonoverlapping_rect],
[unpopulated_text_region, overlapping_rect],
],
["force"],
[True],
)
),
*list(
product(
[no_text_region, text_region, cid_text_region],
[[], [populated_text_region], [unpopulated_text_region]],
[[], [nonoverlapping_rect], [overlapping_rect]],
[
[],
[populated_text_region],
[unpopulated_text_region],
[nonoverlapping_rect],
[overlapping_rect],
[populated_text_region, nonoverlapping_rect],
[populated_text_region, overlapping_rect],
[unpopulated_text_region, nonoverlapping_rect],
[unpopulated_text_region, overlapping_rect],
],
["never"],
[False],
)
),
],
)
def test_ocr_image(region, text_objects, image_objects, ocr_strategy, expected):
assert layout.needs_ocr(region, text_objects, image_objects, ocr_strategy) is expected
def test_ocr_image(region, objects, ocr_strategy, expected):
assert elements.needs_ocr(region, objects, ocr_strategy) is expected


def test_load_pdf():
Expand Down
84 changes: 84 additions & 0 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from random import randint
from unstructured_inference.inference import elements
from unstructured_inference.inference.layout import load_pdf


def intersect_brute(rect1, rect2):
    """Reference oracle: report whether two rectangles share an integer point.

    Enumerates every integer lattice point covered by ``rect1`` (both edges
    included) and checks whether any of them also lies inside ``rect2``
    (both edges included). It is deliberately naive so it can be used to
    cross-check a real intersection implementation.

    Both arguments only need ``x1``, ``y1``, ``x2`` and ``y2`` attributes.
    The coordinates of ``rect1`` must be integers because ``range`` is used
    to enumerate them.

    Returns:
        ``True`` if at least one lattice point of ``rect1`` is inside
        ``rect2``, otherwise ``False``.
    """
    return any(
        (rect2.x1 <= x <= rect2.x2) and (rect2.y1 <= y <= rect2.y2)
        for x in range(rect1.x1, rect1.x2 + 1)
        for y in range(rect1.y1, rect1.y2 + 1)
    )


def rand_rect(size=10):
    """Return a random square ``elements.Rectangle`` with side length ``size``.

    The ``(x1, y1)`` corner is drawn uniformly from ``0 .. 30 - size`` on each
    axis, so the opposite corner never exceeds 30.
    """
    max_origin = 30 - size
    origin_x = randint(0, max_origin)
    origin_y = randint(0, max_origin)
    far_x = origin_x + size
    far_y = origin_y + size
    return elements.Rectangle(origin_x, origin_y, far_x, far_y)


def test_intersects_overlap():
    """Compare ``intersects`` with the brute-force oracle on equal-sized squares.

    For 1000 random pairs, the oracle, ``a.intersects(b)`` and
    ``b.intersects(a)`` must all agree.
    """
    for _trial in range(1000):
        first, second = rand_rect(), rand_rect()
        expected = intersect_brute(first, second)
        forward = first.intersects(second)
        assert expected == forward
        # Intersection must not depend on which rectangle is the receiver.
        assert forward == second.intersects(first)


def test_intersects_subset():
    """Compare ``intersects`` with the brute-force oracle on differently sized squares.

    A default-sized square is paired with a size-20 square for 1000 random
    draws; the oracle and both call directions of ``intersects`` must agree.
    """
    for _trial in range(1000):
        small = rand_rect()
        large = rand_rect(20)
        expected = intersect_brute(small, large)
        forward = small.intersects(large)
        assert expected == forward
        # Intersection must not depend on which rectangle is the receiver.
        assert forward == large.intersects(small)


def test_intersection_of_lots_of_rects():
    """Check the pairwise result of ``elements.intersections`` against the oracle.

    For 1000 random batches of ten size-6 squares, every entry ``[i, j]`` of
    the returned matrix must match the brute-force answer for that pair and
    must equal its mirrored entry ``[j, i]``.
    """
    n_rects = 10
    for _trial in range(1000):
        rects = [rand_rect(6) for _ in range(n_rects)]
        intersection_mtx = elements.intersections(*rects)
        for i, rect_i in enumerate(rects):
            for j, rect_j in enumerate(rects):
                expected = intersect_brute(rect_i, rect_j)
                cell = intersection_mtx[i, j]
                assert expected == cell
                # The matrix is expected to be symmetric.
                assert cell == intersection_mtx[j, i]


def test_rectangle_width_height():
    """Check ``Rectangle.width`` and ``Rectangle.height`` on 1000 random rectangles.

    Each rectangle is built with ``x2 > x1`` and ``y2 > y1``; ``width`` must
    equal ``x2 - x1`` and ``height`` must equal ``y2 - y1``.
    """
    for _trial in range(1000):
        x_lo = randint(0, 50)
        x_hi = randint(x_lo + 1, 100)
        y_lo = randint(0, 50)
        y_hi = randint(y_lo + 1, 100)
        expected_width = x_hi - x_lo
        expected_height = y_hi - y_lo
        rect = elements.Rectangle(x_lo, y_lo, x_hi, y_hi)
        assert rect.width == expected_width
        assert rect.height == expected_height


def test_minimal_containing_rect():
    """Check that ``minimal_containing_region`` is both containing and minimal.

    For 1000 random pairs of squares, the returned region must contain both
    inputs, and shrinking any single side of it by one unit must cause at
    least one of the inputs to no longer fit.
    """
    sides = ("x1", "y1", "x2", "y2")
    for _trial in range(1000):
        rect1 = rand_rect()
        rect2 = rand_rect()
        big_rect = elements.minimal_containing_region(rect1, rect2)
        for shrunk_side in sides:
            # "...1" sides move inward by +1, "...2" sides move inward by -1.
            step = 1 if shrunk_side.endswith("1") else -1
            # Start from any rectangle, then overwrite all four sides below.
            almost_as_big_rect = rand_rect()
            for side in sides:
                value = getattr(big_rect, side)
                if side == shrunk_side:
                    value = value + step
                setattr(almost_as_big_rect, side, value)
            holds_both = rect1.is_in(almost_as_big_rect) and rect2.is_in(almost_as_big_rect)
            assert not holds_both

        assert rect1.is_in(big_rect)
        assert rect2.is_in(big_rect)


def test_partition_groups_from_regions():
    """Partition the first page of the sample paper and check the grouping.

    Expects nine groups from ``partition_groups_from_regions``; after ordering
    groups by the ``y1`` of their first member, the text of the last group
    must start with "Deep".
    """
    # NOTE(review): the first item returned by load_pdf is presumably the
    # per-page lists of word regions; confirm against load_pdf.
    words, _unused = load_pdf("sample-docs/layout-parser-paper.pdf")
    first_page = words[0]
    groups = elements.partition_groups_from_regions(first_page)
    assert len(groups) == 9

    def first_member_y1(group):
        return group[0].y1

    sorted_groups = sorted(groups, key=first_member_y1)
    last_group = sorted_groups[-1]
    text = "".join(el.text for el in last_group)
    assert text.startswith("Deep")
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.3.3-dev2" # pragma: no cover
__version__ = "0.4.0" # pragma: no cover
Loading