diff --git a/CHANGELOG.md b/CHANGELOG.md index ccf2ec03..684d9563 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ -## 0.2.1-dev0 +## 0.2.1-dev1 +* Refactor to facilitate local inference * Removes BasicConfig from logger configuration * Implement auto model downloading diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index 404aa921..05c68736 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -1,12 +1,14 @@ import pytest -from unittest.mock import patch +from unittest.mock import patch, mock_open import layoutparser as lp from layoutparser.elements import Layout, Rectangle, TextBlock import numpy as np from PIL import Image -from unstructured_inference.inference.layout import DocumentLayout, PageLayout +import unstructured_inference.inference.layout as layout +import unstructured_inference.models as models + import unstructured_inference.models.detectron2 as detectron2 import unstructured_inference.models.tesseract as tesseract @@ -28,7 +30,7 @@ def mock_page_layout(): def test_pdf_page_converts_images_to_array(mock_image): - page = PageLayout(number=0, image=mock_image, layout=Layout()) + page = layout.PageLayout(number=0, image=mock_image, layout=Layout()) assert page.image_array is None image_array = page._get_image_array() @@ -47,7 +49,7 @@ def detect(self, *args): monkeypatch.setattr(tesseract, "is_pytesseract_available", lambda *args: True) image = np.random.randint(12, 24, (40, 40)) - page = PageLayout(number=0, image=image, layout=Layout()) + page = layout.PageLayout(number=0, image=image, layout=Layout()) rectangle = Rectangle(1, 2, 3, 4) text_block = TextBlock(rectangle, text=None) @@ -67,7 +69,7 @@ def test_get_page_elements(monkeypatch, mock_page_layout): monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) image = np.random.randint(12, 24, (40, 40)) - page = 
PageLayout(number=0, image=image, layout=mock_page_layout) + page = layout.PageLayout(number=0, image=image, layout=mock_page_layout) elements = page.get_elements(inplace=False) @@ -79,17 +81,17 @@ def test_get_page_elements(monkeypatch, mock_page_layout): def test_get_page_elements_with_ocr(monkeypatch): - monkeypatch.setattr(PageLayout, "ocr", lambda *args: "An Even Catchier Title") + monkeypatch.setattr(layout.PageLayout, "ocr", lambda *args: "An Even Catchier Title") rectangle = Rectangle(2, 4, 6, 8) text_block = TextBlock(rectangle, text=None, type="Title") - layout = Layout([text_block]) + doc_layout = Layout([text_block]) - monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(layout)) + monkeypatch.setattr(detectron2, "load_default_model", lambda: MockLayoutModel(doc_layout)) monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) image = np.random.randint(12, 24, (40, 40)) - page = PageLayout(number=0, image=image, layout=layout) + page = layout.PageLayout(number=0, image=image, layout=doc_layout) page.get_elements() assert str(page) == "An Even Catchier Title" @@ -105,7 +107,7 @@ def test_read_pdf(monkeypatch, mock_page_layout): monkeypatch.setattr(detectron2, "is_detectron2_available", lambda *args: True) with patch.object(lp, "load_pdf", return_value=(layouts, images)): - doc = DocumentLayout.from_file("fake-file.pdf") + doc = layout.DocumentLayout.from_file("fake-file.pdf") assert str(doc).startswith("A Catchy Title") assert str(doc).count("A Catchy Title") == 2 # Once for each page @@ -115,3 +117,62 @@ def test_read_pdf(monkeypatch, mock_page_layout): pages = doc.pages assert str(doc) == "\n\n".join([str(page) for page in pages]) + + +@pytest.mark.parametrize("model_name", [None, "checkbox", "fake"]) +def test_process_data_with_model(monkeypatch, mock_page_layout, model_name): + monkeypatch.setattr(models, "get_model", lambda x: MockLayoutModel(mock_page_layout)) + monkeypatch.setattr( + 
layout.DocumentLayout, + "from_file", + lambda *args, **kwargs: layout.DocumentLayout.from_pages([]), + ) + monkeypatch.setattr( + models, "load_model", lambda *args, **kwargs: MockLayoutModel(mock_page_layout) + ) + monkeypatch.setattr( + models, + "_get_model_loading_info", + lambda *args, **kwargs: ( + "fake-binary-path", + "fake-config-path", + {0: "Unchecked", 1: "Checked"}, + ), + ) + with patch("builtins.open", mock_open(read_data=b"000000")): + assert layout.process_data_with_model(open(""), model_name=model_name) + + +def test_process_data_with_model_raises_on_invalid_model_name(): + with patch("builtins.open", mock_open(read_data=b"000000")): + with pytest.raises(models.UnknownModelException): + layout.process_data_with_model(open(""), model_name="fake") + + +@pytest.mark.parametrize("model_name", [None, "checkbox"]) +def test_process_file_with_model(monkeypatch, mock_page_layout, model_name): + monkeypatch.setattr(models, "get_model", lambda x: MockLayoutModel(mock_page_layout)) + monkeypatch.setattr( + layout.DocumentLayout, + "from_file", + lambda *args, **kwargs: layout.DocumentLayout.from_pages([]), + ) + monkeypatch.setattr( + models, "load_model", lambda *args, **kwargs: MockLayoutModel(mock_page_layout) + ) + monkeypatch.setattr( + models, + "_get_model_loading_info", + lambda *args, **kwargs: ( + "fake-binary-path", + "fake-config-path", + {0: "Unchecked", 1: "Checked"}, + ), + ) + filename = "" + assert layout.process_file_with_model(filename, model_name=model_name) + + +def test_process_file_with_model_raises_on_invalid_model_name(): + with pytest.raises(models.UnknownModelException): + layout.process_file_with_model("", model_name="fake") diff --git a/test_unstructured_inference/models/test_model.py b/test_unstructured_inference/models/test_model.py index d1d393f0..e5cce400 100644 --- a/test_unstructured_inference/models/test_model.py +++ b/test_unstructured_inference/models/test_model.py @@ -24,5 +24,5 @@ def test_get_model(monkeypatch): def 
test_raises_invalid_model(): - with pytest.raises(ValueError): + with pytest.raises(models.UnknownModelException): models.get_model("fake_model") diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index aefc4556..6aa693e8 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.2.1-dev0" # pragma: no cover +__version__ = "0.2.1-dev1" # pragma: no cover diff --git a/unstructured_inference/api.py b/unstructured_inference/api.py index 93b52ab7..fc237d84 100644 --- a/unstructured_inference/api.py +++ b/unstructured_inference/api.py @@ -1,8 +1,7 @@ from fastapi import FastAPI, File, status, Request, UploadFile, Form, HTTPException -from unstructured_inference.inference.layout import DocumentLayout -from unstructured_inference.models import get_model +from unstructured_inference.inference.layout import process_data_with_model +from unstructured_inference.models import UnknownModelException from typing import List -import tempfile app = FastAPI() @@ -15,16 +14,10 @@ async def layout_parsing_pdf( include_elems: List[str] = Form(default=ALL_ELEMS), model: str = Form(default=None), ): - with tempfile.NamedTemporaryFile() as tmp_file: - tmp_file.write(file.file.read()) - if model is None: - layout = DocumentLayout.from_file(tmp_file.name) - else: - try: - detector = get_model(model) - except ValueError as e: - raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, str(e)) - layout = DocumentLayout.from_file(tmp_file.name, model=detector) + try: + layout = process_data_with_model(file.file, model) + except UnknownModelException as e: + raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, str(e)) pages_layout = [ { "number": page.number, diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index 63fe663b..f3e37c85 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py 
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+import tempfile
+from typing import List, Optional, Tuple, Union, BinaryIO
 
 import layoutparser as lp
 from layoutparser.models.detectron2.layoutmodel import Detectron2LayoutModel
@@ -10,6 +11,7 @@
 from unstructured_inference.logger import logger
 import unstructured_inference.models.tesseract as tesseract
 import unstructured_inference.models.detectron2 as detectron2
+from unstructured_inference.models import get_model
 
 
 @dataclass
@@ -136,3 +138,25 @@ def _get_image_array(self) -> Union[np.ndarray, None]:
         if self.image_array is None:
             self.image_array = np.array(self.image)
         return self.image_array
+
+
+def process_data_with_model(data: BinaryIO, model_name: Optional[str] = None) -> DocumentLayout:
+    """Processes pdf file in the form of a file handle (supporting a read method) into a
+    DocumentLayout by using a model identified by model_name. The data is copied to a temporary
+    file whose name is handed to process_file_with_model."""
+    with tempfile.NamedTemporaryFile() as tmp_file:
+        tmp_file.write(data.read())
+        # The file is read back by name below, so push buffered bytes to disk first.
+        tmp_file.flush()
+        layout = process_file_with_model(tmp_file.name, model_name)
+
+    return layout
+
+
+def process_file_with_model(filename: str, model_name: Optional[str] = None) -> DocumentLayout:
+    """Processes pdf file with name filename into a DocumentLayout by using a model identified by
+    model_name. When model_name is None, get_model is not called and model=None is passed to
+    DocumentLayout.from_file."""
+    model = None if model_name is None else get_model(model_name)
+    layout = DocumentLayout.from_file(filename, model=model)
+    return layout
diff --git a/unstructured_inference/models/__init__.py b/unstructured_inference/models/__init__.py
index 57d9554e..b08404fe 100644
--- a/unstructured_inference/models/__init__.py
+++ b/unstructured_inference/models/__init__.py
@@ -24,5 +24,12 @@ def _get_model_loading_info(model: str) -> Tuple[str, str, Dict[int, str]]:
         config_path = hf_hub_download(repo_id, config_fn)
         label_map = {0: "Unchecked", 1: "Checked"}
     else:
-        raise ValueError(f"Unknown model type: {model}")
+        raise UnknownModelException(f"Unknown model type: {model}")
     return model_path, config_path, label_map
+
+
+class UnknownModelException(ValueError):
+    """Exception for the case where a model is called for with an unrecognized identifier.
+
+    Subclasses ValueError, which this code path raised before, so existing
+    `except ValueError` handlers keep working."""