Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
## 0.2.1-dev0
## 0.2.1-dev1

* Refactor to facilitate local inference
* Remove BasicConfig from logger configuration
* Implement auto model downloading

Expand Down
81 changes: 71 additions & 10 deletions test_unstructured_inference/inference/test_layout.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import pytest
from unittest.mock import patch
from unittest.mock import patch, mock_open
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't know about mock_open before!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Me either!


import layoutparser as lp
from layoutparser.elements import Layout, Rectangle, TextBlock
import numpy as np
from PIL import Image

from unstructured_inference.inference.layout import DocumentLayout, PageLayout
import unstructured_inference.inference.layout as layout
import unstructured_inference.models as models

import unstructured_inference.models.detectron2 as detectron2
import unstructured_inference.models.tesseract as tesseract

Expand All @@ -28,7 +30,7 @@ def mock_page_layout():


def test_pdf_page_converts_images_to_array(mock_image):
page = PageLayout(number=0, image=mock_image, layout=Layout())
page = layout.PageLayout(number=0, image=mock_image, layout=Layout())
assert page.image_array is None

image_array = page._get_image_array()
Expand All @@ -47,7 +49,7 @@ def detect(self, *args):
monkeypatch.setattr(tesseract, "is_pytesseract_available", lambda *args: True)

image = np.random.randint(12, 24, (40, 40))
page = PageLayout(number=0, image=image, layout=Layout())
page = layout.PageLayout(number=0, image=image, layout=Layout())
rectangle = Rectangle(1, 2, 3, 4)
text_block = TextBlock(rectangle, text=None)

Expand All @@ -67,7 +69,7 @@ def test_get_page_elements(monkeypatch, mock_page_layout):
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)

image = np.random.randint(12, 24, (40, 40))
page = PageLayout(number=0, image=image, layout=mock_page_layout)
page = layout.PageLayout(number=0, image=image, layout=mock_page_layout)

elements = page.get_elements(inplace=False)

Expand All @@ -79,17 +81,17 @@ def test_get_page_elements(monkeypatch, mock_page_layout):


def test_get_page_elements_with_ocr(monkeypatch):
monkeypatch.setattr(PageLayout, "ocr", lambda *args: "An Even Catchier Title")
monkeypatch.setattr(layout.PageLayout, "ocr", lambda *args: "An Even Catchier Title")

rectangle = Rectangle(2, 4, 6, 8)
text_block = TextBlock(rectangle, text=None, type="Title")
layout = Layout([text_block])
doc_layout = Layout([text_block])

monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(layout))
monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(doc_layout))
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)

image = np.random.randint(12, 24, (40, 40))
page = PageLayout(number=0, image=image, layout=layout)
page = layout.PageLayout(number=0, image=image, layout=doc_layout)
page.get_elements()

assert str(page) == "An Even Catchier Title"
Expand All @@ -105,7 +107,7 @@ def test_read_pdf(monkeypatch, mock_page_layout):
monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True)

with patch.object(lp, "load_pdf", return_value=(layouts, images)):
doc = DocumentLayout.from_file("fake-file.pdf")
doc = layout.DocumentLayout.from_file("fake-file.pdf")

assert str(doc).startswith("A Catchy Title")
assert str(doc).count("A Catchy Title") == 2 # Once for each page
Expand All @@ -115,3 +117,62 @@ def test_read_pdf(monkeypatch, mock_page_layout):

pages = doc.pages
assert str(doc) == "\n\n".join([str(page) for page in pages])


@pytest.mark.parametrize("model_name", [None, "checkbox", "fake"])
def test_process_data_with_model(monkeypatch, mock_page_layout, model_name):
    """process_data_with_model returns a truthy layout for each parametrized model name.

    Model lookup/loading and DocumentLayout.from_file are replaced with stubs, so no model is
    downloaded and no PDF is parsed. Because _get_model_loading_info is stubbed as well, the
    "fake" name is accepted here instead of raising.
    """

    def stub_get_model(x):
        return MockLayoutModel(mock_page_layout)

    def stub_from_file(*args, **kwargs):
        return layout.DocumentLayout.from_pages([])

    def stub_load_model(*args, **kwargs):
        return MockLayoutModel(mock_page_layout)

    def stub_loading_info(*args, **kwargs):
        return (
            "fake-binary-path",
            "fake-config-path",
            {0: "Unchecked", 1: "Checked"},
        )

    monkeypatch.setattr(models, "get_model", stub_get_model)
    monkeypatch.setattr(layout.DocumentLayout, "from_file", stub_from_file)
    monkeypatch.setattr(models, "load_model", stub_load_model)
    monkeypatch.setattr(models, "_get_model_loading_info", stub_loading_info)

    # builtins.open is mocked so that open("") yields a handle whose read() returns fake bytes.
    fake_open = mock_open(read_data=b"000000")
    with patch("builtins.open", fake_open):
        pdf_handle = open("")
        assert layout.process_data_with_model(pdf_handle, model_name=model_name)


def test_process_data_with_model_raises_on_invalid_model_name():
    """An unrecognized model name makes process_data_with_model raise UnknownModelException."""
    fake_open = mock_open(read_data=b"000000")
    with patch("builtins.open", fake_open), pytest.raises(models.UnknownModelException):
        layout.process_data_with_model(open(""), model_name="fake")


@pytest.mark.parametrize("model_name", [None, "checkbox"])
def test_process_file_with_model(monkeypatch, mock_page_layout, model_name):
    """process_file_with_model returns a truthy layout for each parametrized model name.

    Model lookup/loading and DocumentLayout.from_file are replaced with stubs, so the empty
    filename is never actually opened and no model is downloaded.
    """

    def stub_get_model(x):
        return MockLayoutModel(mock_page_layout)

    def stub_from_file(*args, **kwargs):
        return layout.DocumentLayout.from_pages([])

    def stub_load_model(*args, **kwargs):
        return MockLayoutModel(mock_page_layout)

    def stub_loading_info(*args, **kwargs):
        return (
            "fake-binary-path",
            "fake-config-path",
            {0: "Unchecked", 1: "Checked"},
        )

    monkeypatch.setattr(models, "get_model", stub_get_model)
    monkeypatch.setattr(layout.DocumentLayout, "from_file", stub_from_file)
    monkeypatch.setattr(models, "load_model", stub_load_model)
    monkeypatch.setattr(models, "_get_model_loading_info", stub_loading_info)

    pdf_path = ""
    assert layout.process_file_with_model(pdf_path, model_name=model_name)


def test_process_file_with_model_raises_on_invalid_model_name():
    """An unrecognized model name makes process_file_with_model raise UnknownModelException."""
    expected_error = models.UnknownModelException
    with pytest.raises(expected_error):
        layout.process_file_with_model("", model_name="fake")
2 changes: 1 addition & 1 deletion test_unstructured_inference/models/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ def test_get_model(monkeypatch):


def test_raises_invalid_model():
with pytest.raises(ValueError):
with pytest.raises(models.UnknownModelException):
models.get_model("fake_model")
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.1-dev0" # pragma: no cover
__version__ = "0.2.1-dev1" # pragma: no cover
19 changes: 6 additions & 13 deletions unstructured_inference/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from fastapi import FastAPI, File, status, Request, UploadFile, Form, HTTPException
from unstructured_inference.inference.layout import DocumentLayout
from unstructured_inference.models import get_model
from unstructured_inference.inference.layout import process_data_with_model
from unstructured_inference.models import UnknownModelException
from typing import List
import tempfile

app = FastAPI()

Expand All @@ -15,16 +14,10 @@ async def layout_parsing_pdf(
include_elems: List[str] = Form(default=ALL_ELEMS),
model: str = Form(default=None),
):
with tempfile.NamedTemporaryFile() as tmp_file:
tmp_file.write(file.file.read())
if model is None:
layout = DocumentLayout.from_file(tmp_file.name)
else:
try:
detector = get_model(model)
except ValueError as e:
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, str(e))
layout = DocumentLayout.from_file(tmp_file.name, model=detector)
try:
layout = process_data_with_model(file.file, model)
except UnknownModelException as e:
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, str(e))
pages_layout = [
{
"number": page.number,
Expand Down
22 changes: 21 additions & 1 deletion unstructured_inference/inference/layout.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import tempfile
from typing import List, Optional, Tuple, Union, BinaryIO

import layoutparser as lp
from layoutparser.models.detectron2.layoutmodel import Detectron2LayoutModel
Expand All @@ -10,6 +11,7 @@
from unstructured_inference.logger import logger
import unstructured_inference.models.tesseract as tesseract
import unstructured_inference.models.detectron2 as detectron2
from unstructured_inference.models import get_model


@dataclass
Expand Down Expand Up @@ -136,3 +138,21 @@ def _get_image_array(self) -> Union[np.ndarray, None]:
if self.image_array is None:
self.image_array = np.array(self.image)
return self.image_array


def process_data_with_model(data: BinaryIO, model_name: Optional[str]) -> DocumentLayout:
    """Process a PDF supplied as a binary file-like object into a DocumentLayout.

    The bytes read from ``data`` are copied into a named temporary file, and that file's path
    is handed to process_file_with_model. The temporary file is removed when this function
    returns.

    Args:
        data: Readable binary stream (anything supporting ``read``) holding the PDF contents.
        model_name: Identifier of the model to use, forwarded unchanged to
            process_file_with_model. ``None`` is accepted; process_file_with_model then
            passes ``model=None`` to DocumentLayout.from_file.

    Returns:
        The DocumentLayout produced by process_file_with_model.

    Raises:
        Whatever process_file_with_model raises; the API layer catches
        UnknownModelException from this call for unrecognized model names.
    """
    with tempfile.NamedTemporaryFile() as tmp_file:
        tmp_file.write(data.read())
        # Only the file *name* is passed on, so the consumer re-opens the file by path.
        # Flush first: otherwise a small payload may still sit in the write buffer and the
        # reader would see an empty or truncated file.
        tmp_file.flush()
        layout = process_file_with_model(tmp_file.name, model_name)

    return layout


def process_file_with_model(filename: str, model_name: Optional[str]) -> DocumentLayout:
    """Process the PDF file at ``filename`` into a DocumentLayout.

    Args:
        filename: Path of the PDF file to process.
        model_name: Identifier passed to get_model to obtain the model. When ``None``,
            get_model is not called and ``model=None`` is passed to
            DocumentLayout.from_file.

    Returns:
        The DocumentLayout returned by DocumentLayout.from_file.

    Raises:
        Whatever get_model raises for ``model_name``; callers in this project expect
        UnknownModelException for unrecognized names.
    """
    model = None if model_name is None else get_model(model_name)
    return DocumentLayout.from_file(filename, model=model)
8 changes: 7 additions & 1 deletion unstructured_inference/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,11 @@ def _get_model_loading_info(model: str) -> Tuple[str, str, Dict[int, str]]:
config_path = hf_hub_download(repo_id, config_fn)
label_map = {0: "Unchecked", 1: "Checked"}
else:
raise ValueError(f"Unknown model type: {model}")
raise UnknownModelException(f"Unknown model type: {model}")
return model_path, config_path, label_map


class UnknownModelException(Exception):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:-)

"""Exception for the case where a model is called for with an unrecognized identifier."""

pass