diff --git a/CHANGELOG.md b/CHANGELOG.md index d1add3a2..f6e59d30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,6 @@ -## 0.2.13-dev0 +## 0.2.13 +* Add table processing * Change OCR logic to be aware of PDF image elements ## 0.2.12 diff --git a/Makefile b/Makefile index 09c498f5..3ac85a38 100644 --- a/Makefile +++ b/Makefile @@ -20,7 +20,7 @@ install-base: install-base-pip-packages install: install-base-pip-packages install-dev install-detectron2 install-test .PHONY: install-ci -install-ci: install-base-pip-packages install-test +install-ci: install-base-pip-packages install-test install-paddleocr .PHONY: install-base-pip-packages install-base-pip-packages: @@ -31,6 +31,10 @@ install-base-pip-packages: install-detectron2: pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@78d5b4f335005091fe0364ce4775d711ec93566e" +.PHONY: install-paddleocr +install-paddleocr: + pip install "unstructured.PaddleOCR" + .PHONY: install-test install-test: pip install -r requirements/test.txt diff --git a/README.md b/README.md index fb4606b6..571b9588 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,17 @@ Windows is not officially supported by Detectron2, but some users are able to in See discussion [here](https://layout-parser.github.io/tutorials/installation#for-windows-users) for tips on installing Detectron2 on Windows. +### PaddleOCR + +[PaddleOCR](https://github.com/Unstructured-IO/unstructured.PaddleOCR) is required for table processing for `x86_64` architectures. +It should not be installed under MacOS with Apple Silicon cpu. + +PaddleOCR should be installed using the following instructions. + +```shell +pip install "unstructured.PaddleOCR" +``` + ### Repository To install the repository for development, clone the repo and run `make install` to install dependencies. diff --git a/sample-docs/example_table.jpg b/sample-docs/example_table.jpg new file mode 100644 index 00000000..c17ce7bf Binary files /dev/null and b/sample-docs/example_table.jpg differ diff --git a/setup.cfg b/setup.cfg index 78525135..dd54ef0d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,7 +3,7 @@ license_files = LICENSE.md [flake8] max-line-length = 100 -ignore = D100, D101, D104, D105, D107, D2, D4 +extend-ignore = D100, D101, D104, D105, D107, D2, D4 per-file-ignores = test_*/**: D diff --git a/setup.py b/setup.py index bd19583f..9a32d82f 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ limitations under the License. """ from setuptools import setup, find_packages +from platform import machine from unstructured_inference.__version__ import __version__ @@ -60,5 +61,5 @@ "onnxruntime", "transformers", ], - extras_require={}, + extras_require={"paddle-ocr": "unstructured.PaddleOCR"}, ) diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index 6e322a39..4b4b4d7f 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -186,11 +186,12 @@ def points(self): class MockPageLayout(layout.PageLayout): - def __init__(self, layout=None, model=None, ocr_strategy="auto"): + def __init__(self, layout=None, model=None, ocr_strategy="auto", extract_tables=False): self.image = None self.layout = layout self.model = model self.ocr_strategy = ocr_strategy + self.extract_tables = extract_tables def ocr(self, text_block: MockTextRegion): return text_block.ocr_text diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py new file mode 100644 index 00000000..68e29cb5 --- /dev/null +++ b/test_unstructured_inference/models/test_tables.py @@ -0,0 +1,569 @@ +import pytest +from unittest.mock import patch + +from transformers.models.table_transformer.modeling_table_transformer import TableTransformerDecoder + +import unstructured_inference.models.tables as tables +import unstructured_inference.models.table_postprocess as postprocess + + +@pytest.mark.parametrize( + "model_path", + [ + ("invalid_table_path"), + ("incorrect_table_path"), + ], +) +def test_load_table_model_raises_when_not_available(model_path): + with pytest.raises(ImportError): + table_model = tables.UnstructuredTableTransformerModel() + table_model.initialize(model=model_path) + + +@pytest.mark.parametrize( + "model_path", + [ + "microsoft/table-transformer-structure-recognition", + ], +) +def test_load_donut_model(model_path): + table_model = tables.UnstructuredTableTransformerModel() + table_model.initialize(model=model_path) + assert type(table_model.model.model.decoder) == TableTransformerDecoder + + +@pytest.fixture +def sample_table_transcript(platform_type): + if platform_type == "x86_64": + out = ( + '<" + 'td colspan="2" rowspan="2">This EXAMPLE event includes services like: Primary care ' + "physician office visits (including disease education) Diagnostic tests (blood work) " + 'Prescription drugs Durable medical equipment (glucose meter)<" + 'td>The total Joe would pay is
About these Coverage Examples:
' + 'This is not a cost estimator. Treatments shown are just examples ' + "of how this plan might cover medical care. Your actual costs will be different " + "depending on the actual care you receive, the prices your providers charge, and many " + "other factors. Focus on the cost sharing amounts (deductibles, copayments and " + "coinsurance) and excluded services under the plan. Use this information to compare " + "the portion of costs you might pay under different health plans. Please note these " + 'coverage examples are based on self-only coverage
' + "Peg is Having a Baby (9 months of in-network pre-natal care and a hospital delivery)Managing Joe's type 2 Diabetes (a year of routine in-network care of a well- " + 'controlled condition)Mia\'s Simple Fracture (in-network ' + 'emergency room visit and follow up care)
The plan\'s ' + "overall deductible $750 Specialist copayment $50 Hospital (facility) coinsurance 10" + "% Other coinsurance 10%The plan's overall deductible Specialist copayment " + r"Hospital (facility) coinsurance Other coinsurance$750 $50 10% 10%" + "The plan's overall deductible Specialist copayment Hospital (facility) coinsurance " + r'Other coinsurance$750 $50 10% 10%
' + "This EXAMPLE event includes services like: Specialist office visits (prenatal care) " + "Childbirth/Delivery Professional Services Childbirth/Delivery Facility ServicesThis EXAMPLE event includes services like: Emergency room care (including ' + "medical Diagnostic test (x-ray) Durable medical equipment (crutches) Rehabilitation " + 'services (physical therapy)
Diagnostic tests (' + "ultrasounds and blood work) Specialist visit (anesthesia)
Total " + "Example Cost$12,700Total Example Cost$5,600Total " + 'Example Cost$2,800
In this example, Peg would ' + 'pay:In this example, Joe would pay:In this example, Mia ' + 'would pay:
Cost SharingCost SharingCost Sharing
Deductibles$750" + "Deductibles$120Deductibles$750
Copayments" + '$30Copayments$700' + 'Copayments $400 Coinsurance $30
Coinsurance $' + "1,200 What isn't coveredCoinsurance$0
What isn't " + "coveredWhat isn't covered
Limits or " + "exclusions$20Limits or exclusions$20Limits or " + "exclusions$0
The total Peg would pay is$2,000$840The total Mia would ' + "pay is $1,180
Plan Name: NVIDIA PPO PlanPIan ID: 14603022The plan would be responsible for the other costs of these EXAMPLE covered " + "servicesPage 8 of 8
" + ) + else: + out = ( + '<' + 'tr>
About these Coverage Examples:
' + "This is not a cost depending on the (deductibles, pay under differentestimator. |reatments shown are just examples of how this plan might ' + "cover medical care. Your actual costs will be different actual care you receive, the " + "prices your providers charge, and many other factors. Focus on the cost sharing " + "amounts copayments and coinsurance) and excluded services under the plan. Use this " + "information to compare the portion of costs you might health plans. Please note these " + 'coverage examples are based on self-only coverage.
' + "Peg is Having a Baby (9 months of in-network pre-natal care and a hospital delivery)Managing Joe's type 2 (a year of routine in-network care controlled conaition" + ')Diabetes of a well-Mia\'s Simple Fracture (in-network ' + 'emergency room visit and follow up care)
= The plan' + "'s overall deductible $750 = Specialist copayment $50 = Hospital (facility) " + "coinsurance 10% = Other coinsurance 10%= The plan's overall deductible = " + "Specialist copayment = Hospital (facility) coinsurance = Other coinsurance$" + r"750 $50 10% 10%= The plan's overall deductible = Specialist copayment = " + r"Hospital (facility) coinsurance = Other coinsurance$750 $50 10% 10%
This EXAMPLE event includes services like: ' + "specialist office visits (prenatal care) Childbirth/Delivery Professional Services " + 'Childbirth/Delivery Facility ServicesThis EXAMPLE ' + "event includes services like: Primary care physician office visits (including aisease " + "education) Diagnostic tests (b/ood work) Prescription drugs Durable medical equipment " + '(/g/ucose meter)This EXAMPLE event includes services ' + "like: Emergency room care (including meaical suoplies) Diagnostic test (x-ray) " + "Durable medical equipment (crutches) Rehabilitation services (o/hysical therapy)
Diagnostic tests (u/trasounas and blood work) specialist ' + "visit (anesthesia)
Total Example Cost| $12,700" + "Total Example Cost |$5,600Total Example Cost| $2,800
In this example, Peg would pay:In this example, Joe ' + 'would pay:In this example, Mia would pay:
Cost ' + 'SharingCost SharingCost Sharing
Deductibles$/50Deductibles$120" + "Deductibles$/50
Copayments$30Copayments$/00Copayments $400 Coinsurance $30
Coinsurance $1,200 What isn t covered' + "Coinsurance
What isnt coveredWhat isnt " + "covered
Limits or exclusions$20Limits or " + "exclusions |$20Limits or exclusions
The " + "total Peg would pay is$2,000The total Joe would pay is9840" + 'The total Mia would pay is $1,180
Plan Name: ' + "NVIDIA PPO PlanThe plan would Plan ID: 14603022be responsible for " + "the other costs of theseEXAMPLEcovered services.Page 8 of 8" + "
" + ) + return out + + +@pytest.mark.parametrize( + "input_test, output_test", + [ + ( + [ + { + "label": "table column header", + "score": 0.9349299073219299, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + { + "label": "table column header", + "score": 0.934, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + }, + ], + [ + { + "label": "table column header", + "score": 0.9349299073219299, + "bbox": [ + 47.83147430419922, + 116.8877944946289, + 2557.79296875, + 216.98883056640625, + ], + } + ], + ), + ([], []), + ], +) +def test_nms(input_test, output_test): + output = postprocess.nms(input_test) + + assert output == output_test + + +@pytest.mark.parametrize( + "supercell1, supercell2", + [ + ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [0, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + "projected row header": False, + "header": False, + "row_numbers": [3, 4, 6], + "column_numbers": [0, 4], + }, + ), + ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [0, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + "projected row header": False, + "header": False, + "row_numbers": [4], + "column_numbers": [0, 4, 6], + }, + ), + ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [1, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + "projected row header": False, + "header": False, + "row_numbers": [4], + "column_numbers": [0, 4, 6], + }, + ), + ( + { + "label": "table spanning cell", + "score": 0.526617169380188, + "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625], + "projected row header": False, + "header": False, + "row_numbers": [3, 4], + "column_numbers": [1, 4], + }, + { + "label": "table spanning cell", + "score": 0.5199193954467773, + "bbox": [ + 98.92312622070312, + 676.1566772460938, + 751.0982666015625, + 938.5986938476562, + ], + "projected row header": False, + "header": False, + "row_numbers": [2, 4, 5, 6, 7, 8], + "column_numbers": [0, 4, 6], + }, + ), + ], +) +def test_remove_supercell_overlap(supercell1, supercell2): + assert postprocess.remove_supercell_overlap(supercell1, supercell2) is None + + +@pytest.mark.parametrize( + ("supercells", "rows", "columns", "output_test"), + [ + ( + [ + { + "label": "table spanning cell", + "score": 0.9, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 2115.197265625, + 1238.27587890625, + ], + "projected row header": True, + "header": True, + "span": True, + }, + ], + [ + { + "label": "table row", + "score": 0.9299452900886536, + "bbox": [0, 0, 10, 10], + "column header": True, + "header": True, + }, + { + "label": "table row", + "score": 0.9299452900886536, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 2114.3525390625, + 193.67681884765625, + ], + "column header": True, + "header": True, + }, + { + "label": "table row", + "score": 0.9299452900886536, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 2114.3525390625, + 193.67681884765625, + ], + "column header": True, + "header": True, + }, + ], + [ + { + "label": "table column", + "score": 0.9996132254600525, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 517.6508178710938, + 1616.48779296875, + ], + }, + { + "label": "table column", + "score": 0.9935646653175354, + "bbox": [ + 520.0474853515625, + 143.11549377441406, + 751.0982666015625, + 1616.48779296875, + ], + }, + ], + [ + { + "label": "table spanning cell", + "score": 0.9, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 751.0982666015625, + 193.67681884765625, + ], + "projected row header": True, + "header": True, + "span": True, + "row_numbers": [1, 2], + "column_numbers": [0, 1], + }, + { + "row_numbers": [0], + "column_numbers": [0, 1], + "score": 0.9, + "propagated": True, + "bbox": [ + 98.92312622070312, + 143.11549377441406, + 751.0982666015625, + 193.67681884765625, + ], + }, + ], + ) + ], +) +def test_align_supercells(supercells, rows, columns, output_test): + assert postprocess.align_supercells(supercells, rows, columns) == output_test + + +@pytest.mark.parametrize("rows, bbox, output", [([1.0], [0.0], [1.0])]) +def test_align_rows(rows, bbox, output): + assert postprocess.align_rows(rows, bbox) == output + + +@pytest.mark.parametrize( + ("model_path", "platform_type"), + [ + ("microsoft/table-transformer-structure-recognition", "arm64"), + ("microsoft/table-transformer-structure-recognition", "x86_64"), + ], +) +def test_table_prediction(model_path, sample_table_transcript, platform_type): + with patch("platform.machine", return_value=platform_type): + table_model = tables.UnstructuredTableTransformerModel() + from PIL import Image + + table_model.initialize(model=model_path) + img = Image.open("./sample-docs/example_table.jpg").convert("RGB") + prediction = table_model.predict(img) + assert prediction == sample_table_transcript + + +def test_intersect(): + a = postprocess.Rect() + b = postprocess.Rect([1, 2, 3, 4]) + assert a.intersect(b).get_area() == 4.0 + + +def test_include_rect(): + a = postprocess.Rect() + assert a.include_rect([1, 2, 3, 4]).get_area() == 4.0 + + +@pytest.mark.parametrize( + ("spans", "join_with_space", "expected"), + [ + ( + [ + { + "flags": 2**0, + "text": "5", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + } + ], + True, + "", + ), + ( + [ + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + } + ], + True, + "p", + ), + ( + [ + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + ], + True, + "p p", + ), + ( + [ + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 1, + }, + ], + True, + "p p", + ), + ( + [ + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 0, + }, + { + "flags": 2**0, + "text": "p", + "superscript": False, + "span_num": 0, + "line_num": 0, + "block_num": 1, + }, + ], + False, + "p p", + ), + ], +) +def test_extract_text_from_spans(spans, join_with_space, expected): + res = postprocess.extract_text_from_spans( + spans, join_with_space=join_with_space, remove_integer_superscripts=True + ) + assert res == expected + + +@pytest.mark.parametrize( + ("supercells", "expected_len"), + [ + ([{"header": "hi", "row_numbers": [0, 1, 2], "score": 0.9}], 1), + ( + [ + { + "header": "hi", + "row_numbers": [0], + "column_numbers": [1, 2, 3], + "score": 0.9, + }, + {"header": "hi", "row_numbers": [1], "column_numbers": [1], "score": 0.9}, + {"header": "hi", "row_numbers": [1], "column_numbers": [2], "score": 0.9}, + {"header": "hi", "row_numbers": [1], "column_numbers": [3], "score": 0.9}, + ], + 4, + ), + ( + [ + {"header": "hi", "row_numbers": [0], "column_numbers": [0], "score": 0.9}, + {"header": "hi", "row_numbers": [1], "column_numbers": [0], "score": 0.9}, + {"header": "hi", "row_numbers": [1, 2], "column_numbers": [0], "score": 0.9}, + {"header": "hi", "row_numbers": [3], "column_numbers": [0], "score": 0.9}, + ], + 3, + ), + ], +) +def test_header_supercell_tree(supercells, expected_len): + postprocess.header_supercell_tree(supercells) + assert len(supercells) == expected_len diff --git a/test_unstructured_inference/models/test_yolox.py b/test_unstructured_inference/models/test_yolox.py index ab749b54..b03565d0 100644 --- a/test_unstructured_inference/models/test_yolox.py +++ b/test_unstructured_inference/models/test_yolox.py @@ -22,7 +22,7 @@ def test_layout_yolox_local_parsing_pdf(): filename = os.path.join("sample-docs", "loremipsum.pdf") document_layout = process_file_with_model(filename, model_name="yolox") content = str(document_layout) - assert "Lorem ipsum" in content + assert "libero fringilla" in content assert len(document_layout.pages) == 1 # NOTE(benjamin) The example sent to the test contains 5 detections assert len(document_layout.pages[0].elements) == 5 @@ -57,7 +57,7 @@ def test_layout_yolox_local_parsing_pdf_soft(): filename = os.path.join("sample-docs", "loremipsum.pdf") document_layout = process_file_with_model(filename, model_name="yolox_tiny") content = str(document_layout) - assert "Lorem ipsum" in content + assert "libero fringilla" in content assert len(document_layout.pages) == 1 # NOTE(benjamin) Soft version of the test, run make test-long in order to run with full model assert len(document_layout.pages[0].elements) > 0 diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index b42b08af..53cb2730 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.2.13-dev0" # pragma: no cover +__version__ = "0.2.13" # pragma: no cover diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index cc7f1cfb..297c7cdf 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -14,6 +14,7 @@ from unstructured_inference.inference.elements import TextRegion, ImageTextRegion, LayoutElement from unstructured_inference.logger import logger import unstructured_inference.models.tesseract as tesseract +import unstructured_inference.models.tables as tables from unstructured_inference.models.base import get_model from unstructured_inference.models.unstructuredmodel import UnstructuredModel @@ -54,6 +55,7 @@ def from_file( model: Optional[UnstructuredModel] = None, fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None, ocr_strategy: str = "auto", + extract_tables: bool = False, ) -> DocumentLayout: """Creates a DocumentLayout from a pdf file.""" logger.info(f"Reading PDF for file: {filename} ...") @@ -74,6 +76,7 @@ def from_file( layout=layout, ocr_strategy=ocr_strategy, fixed_layout=fixed_layout, + extract_tables=extract_tables, ) pages.append(page) return cls.from_pages(pages) @@ -85,6 +88,7 @@ def from_image_file( model: Optional[UnstructuredModel] = None, ocr_strategy: str = "auto", fixed_layout: Optional[List[TextRegion]] = None, + extract_tables: bool = False, ) -> DocumentLayout: """Creates a DocumentLayout from an image file.""" logger.info(f"Reading image file: {filename} ...") @@ -96,7 +100,12 @@ def from_image_file( else: raise FileNotFoundError(f'File "{filename}" not found!') from e page = PageLayout.from_image( - image, model=model, layout=None, ocr_strategy=ocr_strategy, fixed_layout=fixed_layout + image, + model=model, + layout=None, + ocr_strategy=ocr_strategy, + fixed_layout=fixed_layout, + extract_tables=extract_tables, ) return cls.from_pages([page]) @@ -111,6 +120,7 @@ def __init__( layout: Optional[List[TextRegion]], model: Optional[UnstructuredModel] = None, ocr_strategy: str = "auto", + extract_tables: bool = False, ): self.image = image self.image_array: Union[np.ndarray, None] = None @@ -121,6 +131,7 @@ def __init__( if ocr_strategy not in VALID_OCR_STRATEGIES: raise ValueError(f"ocr_strategy must be one of {VALID_OCR_STRATEGIES}.") self.ocr_strategy = ocr_strategy + self.extract_tables = extract_tables def __str__(self) -> str: return "\n\n".join([str(element) for element in self.elements]) @@ -148,7 +159,11 @@ def get_elements_from_layout(self, layout: List[TextRegion]) -> List[LayoutEleme layout.sort(key=lambda element: element.y1) elements = [] for e in tqdm(layout): - elements.append(get_element_from_block(e, self.image, self.layout, self.ocr_strategy)) + elements.append( + get_element_from_block( + e, self.image, self.layout, self.ocr_strategy, self.extract_tables + ) + ) return elements def _get_image_array(self) -> Union[np.ndarray, None]: @@ -164,10 +179,18 @@ def from_image( model: Optional[UnstructuredModel] = None, layout: Optional[List[TextRegion]] = None, ocr_strategy: str = "auto", + extract_tables: bool = False, fixed_layout: Optional[List[TextRegion]] = None, ): """Creates a PageLayout from an already-loaded PIL Image.""" - page = cls(number=0, image=image, layout=layout, model=model, ocr_strategy=ocr_strategy) + page = cls( + number=0, + image=image, + layout=layout, + model=model, + ocr_strategy=ocr_strategy, + extract_tables=extract_tables, + ) if fixed_layout is None: page.get_elements() else: @@ -239,12 +262,15 @@ def get_element_from_block( image: Optional[Image.Image] = None, pdf_objects: Optional[List[Union[TextRegion, ImageTextRegion]]] = None, ocr_strategy: str = "auto", + extract_tables: bool = False, ) -> LayoutElement: """Creates a LayoutElement from a given layout or image by finding all the text that lies within a given block.""" if block.text is not None: # If block text is already populated, we'll assume it's correct text = block.text + elif extract_tables and isinstance(block, LayoutElement) and block.type == "Table": + text = interprete_table_block(block, image) elif pdf_objects is not None: text = aggregate_by_block(block, image, pdf_objects, ocr_strategy) elif image is not None: @@ -284,11 +310,21 @@ def aggregate_by_block( return text +def interprete_table_block(text_block: TextRegion, image: Image.Image) -> str: + """Extract the contents of a table.""" + tables.load_agent() + if tables.tables_agent is None: + raise RuntimeError("Unable to load table extraction agent.") + padded_block = text_block.pad(12) + cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) + return tables.tables_agent.predict(cropped_image) + + def ocr(text_block: TextRegion, image: Image.Image) -> str: """Runs a cropped text block image through and OCR agent.""" logger.debug("Running OCR on text block ...") tesseract.load_agent() - padded_block = text_block.pad(5) + padded_block = text_block.pad(12) cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2)) return tesseract.ocr_agent.detect(cropped_image) @@ -309,7 +345,7 @@ def load_pdf( vertical_ttb: bool = True, # Should vertical words be read top-to-bottom? extra_attrs: Optional[List[str]] = None, split_at_punctuation: Union[bool, str] = False, - dpi: int = 72, + dpi: int = 200, ) -> Tuple[List[List[TextRegion]], List[Image.Image]]: """Loads the image and word objects from a pdf using pdfplumber and the image renderings of the pdf pages using pdf2image""" diff --git a/unstructured_inference/models/paddle_ocr.py b/unstructured_inference/models/paddle_ocr.py new file mode 100644 index 00000000..3fee51e0 --- /dev/null +++ b/unstructured_inference/models/paddle_ocr.py @@ -0,0 +1,12 @@ +from unstructured_paddleocr import PaddleOCR + +paddle_ocr: PaddleOCR = None + + +def load_agent(): + """Loads the PaddleOCR agent as a global variable to ensure that we only load it once.""" + + global paddle_ocr + paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en", mkl_dnn=True, show_log=False) + + return paddle_ocr diff --git a/unstructured_inference/models/table_postprocess.py b/unstructured_inference/models/table_postprocess.py new file mode 100644 index 00000000..67a59a0c --- /dev/null +++ b/unstructured_inference/models/table_postprocess.py @@ -0,0 +1,623 @@ +# https://github.com/microsoft/table-transformer/blob/main/src/postprocess.py +""" +Copyright (C) 2021 Microsoft Corporation +""" +from collections import defaultdict + + +class Rect: + def __init__(self, bbox=None): + if bbox is None: + self.x_min = 0 + self.y_min = 0 + self.x_max = 0 + self.y_max = 0 + else: + self.x_min = bbox[0] + self.y_min = bbox[1] + self.x_max = bbox[2] + self.y_max = bbox[3] + + def get_area(self): + """Calculates the area of the rectangle""" + area = (self.x_max - self.x_min) * (self.y_max - self.y_min) + return area if area > 0 else 0.0 + + def intersect(self, other): + """Calculates the intersection with another rectangle""" + if self.get_area() == 0: + self.x_min = other.x_min + self.y_min = other.y_min + self.x_max = other.x_max + self.y_max = other.y_max + else: + self.x_min = max(self.x_min, other.x_min) + self.y_min = max(self.y_min, other.y_min) + self.x_max = min(self.x_max, other.x_max) + self.y_max = min(self.y_max, other.y_max) + + if self.x_min > self.x_max or self.y_min > self.y_max or self.get_area() == 0: + self.x_min = 0 + self.y_min = 0 + self.x_max = 0 + self.y_max = 0 + + return self + + def include_rect(self, bbox): + """Calculates a rectangle that includes both rectangles""" + other = Rect(bbox) + + if self.get_area() == 0: + self.x_min = other.x_min + self.y_min = other.y_min + self.x_max = other.x_max + self.y_max = other.y_max + return self + + self.x_min = min(self.x_min, other.x_min) + self.y_min = min(self.y_min, other.y_min) + self.x_max = max(self.x_max, other.x_max) + self.y_max = max(self.y_max, other.y_max) + + # if self.get_area() == 0: + # self.x_min = other.x_min + # self.y_min = other.y_min + # self.x_max = other.x_max + # self.y_max = other.y_max + + return self + + def get_bbox(self): + """Returns the coordinates that define the rectangle""" + return [self.x_min, self.y_min, self.x_max, self.y_max] + + +def apply_threshold(objects, threshold): + """ + Filter out objects below a certain score. + """ + return [obj for obj in objects if obj["score"] >= threshold] + + +# def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds): +# """ +# Filter out bounding boxes whose confidence is below the confidence threshold for +# its associated class label. +# """ +# # Apply class-specific thresholds +# indices_above_threshold = [ +# idx +# for idx, (score, label) in enumerate(zip(scores, labels)) +# if score >= class_thresholds[class_names[label]] +# ] +# bboxes = [bboxes[idx] for idx in indices_above_threshold] +# scores = [scores[idx] for idx in indices_above_threshold] +# labels = [labels[idx] for idx in indices_above_threshold] + +# return bboxes, scores, labels + + +def refine_rows(rows, tokens, score_threshold): + """ + Apply operations to the detected rows, such as + thresholding, NMS, and alignment. + """ + + if len(tokens) > 0: + rows = nms_by_containment(rows, tokens, overlap_threshold=0.5) + remove_objects_without_content(tokens, rows) + else: + rows = nms(rows, match_criteria="object2_overlap", match_threshold=0.5, keep_higher=True) + if len(rows) > 1: + rows = sort_objects_top_to_bottom(rows) + + return rows + + +def refine_columns(columns, tokens, score_threshold): + """ + Apply operations to the detected columns, such as + thresholding, NMS, and alignment. + """ + + if len(tokens) > 0: + columns = nms_by_containment(columns, tokens, overlap_threshold=0.5) + remove_objects_without_content(tokens, columns) + else: + columns = nms( + columns, match_criteria="object2_overlap", match_threshold=0.25, keep_higher=True + ) + if len(columns) > 1: + columns = sort_objects_left_to_right(columns) + + return columns + + +def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5): + """ + Non-maxima suppression (NMS) of objects based on shared containment of other objects. + """ + container_objects = sort_objects_by_score(container_objects) + num_objects = len(container_objects) + suppression = [False for obj in container_objects] + + packages_by_container, _, _ = slot_into_containers( + container_objects, + package_objects, + overlap_threshold=overlap_threshold, + forced_assignment=False, + ) + + for object2_num in range(1, num_objects): + object2_packages = set(packages_by_container[object2_num]) + if len(object2_packages) == 0: + suppression[object2_num] = True + for object1_num in range(object2_num): + if not suppression[object1_num]: + object1_packages = set(packages_by_container[object1_num]) + if len(object2_packages.intersection(object1_packages)) > 0: + suppression[object2_num] = True + + final_objects = [obj for idx, obj in enumerate(container_objects) if not suppression[idx]] + return final_objects + + +def slot_into_containers( + container_objects, + package_objects, + overlap_threshold=0.5, + forced_assignment=False, +): + """ + Slot a collection of objects into the container they occupy most (the container which holds the + largest fraction of the object). + """ + best_match_scores = [] + + container_assignments = [[] for container in container_objects] + package_assignments = [[] for package in package_objects] + + if len(container_objects) == 0 or len(package_objects) == 0: + return container_assignments, package_assignments, best_match_scores + + match_scores = defaultdict(dict) + for package_num, package in enumerate(package_objects): + match_scores = [] + package_rect = Rect(package["bbox"]) + package_area = package_rect.get_area() + for container_num, container in enumerate(container_objects): + container_rect = Rect(container["bbox"]) + intersect_area = container_rect.intersect(Rect(package["bbox"])).get_area() + overlap_fraction = intersect_area / package_area + + match_scores.append( + {"container": container, "container_num": container_num, "score": overlap_fraction} + ) + + sorted_match_scores = sort_objects_by_score(match_scores) + + best_match_score = sorted_match_scores[0] + best_match_scores.append(best_match_score["score"]) + if forced_assignment or best_match_score["score"] >= overlap_threshold: + container_assignments[best_match_score["container_num"]].append(package_num) + package_assignments[package_num].append(best_match_score["container_num"]) + + return container_assignments, package_assignments, best_match_scores + + +def sort_objects_by_score(objects, reverse=True): + """ + Put any set of objects in order from high score to low score. + """ + return sorted(objects, key=lambda k: k["score"], reverse=reverse) + + +def remove_objects_without_content(page_spans, objects): + """ + Remove any objects (these can be rows, columns, supercells, etc.) that don't + have any text associated with them. + """ + for obj in objects[:]: + object_text, _ = extract_text_inside_bbox(page_spans, obj["bbox"]) + if len(object_text.strip()) == 0: + objects.remove(obj) + + +def extract_text_inside_bbox(spans, bbox): + """ + Extract the text inside a bounding box. + """ + bbox_spans = get_bbox_span_subset(spans, bbox) + bbox_text = extract_text_from_spans(bbox_spans, remove_integer_superscripts=True) + + return bbox_text, bbox_spans + + +def get_bbox_span_subset(spans, bbox, threshold=0.5): + """ + Reduce the set of spans to those that fall within a bounding box. + + threshold: the fraction of the span that must overlap with the bbox. + """ + span_subset = [] + for span in spans: + if overlaps(span["bbox"], bbox, threshold): + span_subset.append(span) + return span_subset + + +def overlaps(bbox1, bbox2, threshold=0.5): + """ + Test if more than "threshold" fraction of bbox1 overlaps with bbox2. + """ + rect1 = Rect(list(bbox1)) + area1 = rect1.get_area() + if area1 == 0: + return False + return rect1.intersect(Rect(list(bbox2))).get_area() / area1 >= threshold + + +def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True): + """ + Convert a collection of page tokens/words/spans into a single text string. + """ + + join_char = " " if join_with_space else "" + spans_copy = spans[:] + + if remove_integer_superscripts: + for span in spans: + if "flags" not in span: + continue + flags = span["flags"] + if flags & 2**0: # superscript flag + if span["text"].strip().isdigit(): + spans_copy.remove(span) + else: + span["superscript"] = True + + if len(spans_copy) == 0: + return "" + + spans_copy.sort(key=lambda span: span["span_num"]) + spans_copy.sort(key=lambda span: span["line_num"]) + spans_copy.sort(key=lambda span: span["block_num"]) + + # Force the span at the end of every line within a block to have exactly one space + # unless the line ends with a space or ends with a non-space followed by a hyphen + line_texts = [] + line_span_texts = [spans_copy[0]["text"]] + for span1, span2 in zip(spans_copy[:-1], spans_copy[1:]): + if ( + not span1["block_num"] == span2["block_num"] + or not span1["line_num"] == span2["line_num"] + ): + line_text = join_char.join(line_span_texts).strip() + if ( + len(line_text) > 0 + and not line_text[-1] == " " + and not (len(line_text) > 1 and line_text[-1] == "-" and not line_text[-2] == " ") + ): + if not join_with_space: + line_text += " " + line_texts.append(line_text) + line_span_texts = [span2["text"]] + else: + line_span_texts.append(span2["text"]) + line_text = join_char.join(line_span_texts) + line_texts.append(line_text) + + return join_char.join(line_texts).strip() + + +def sort_objects_left_to_right(objs): + """ + Put the objects in order from left to right. + """ + return sorted(objs, key=lambda k: k["bbox"][0] + k["bbox"][2]) + + +def sort_objects_top_to_bottom(objs): + """ + Put the objects in order from top to bottom. + """ + return sorted(objs, key=lambda k: k["bbox"][1] + k["bbox"][3]) + + +def align_columns(columns, bbox): + """ + For every column, align the top and bottom boundaries to the final + table bounding box. + """ + try: + for column in columns: + column["bbox"][1] = bbox[1] + column["bbox"][3] = bbox[3] + except Exception as err: + print("Could not align columns: {}".format(err)) + pass + + return columns + + +def align_rows(rows, bbox): + """ + For every row, align the left and right boundaries to the final + table bounding box. + """ + try: + for row in rows: + row["bbox"][0] = bbox[0] + row["bbox"][2] = bbox[2] + except Exception as err: + print("Could not align rows: {}".format(err)) + pass + + return rows + + +def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_higher=True): + """ + A customizable version of non-maxima suppression (NMS). + + Default behavior: If a lower-confidence object overlaps more than 5% of its area + with a higher-confidence object, remove the lower-confidence object. + + objects: set of dicts; each object dict must have a 'bbox' and a 'score' field + match_criteria: how to measure how much two objects "overlap" + match_threshold: the cutoff for determining that overlap requires suppression of one object + keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower + """ + if len(objects) == 0: + return [] + + objects = sort_objects_by_score(objects, reverse=keep_higher) + + num_objects = len(objects) + suppression = [False for obj in objects] + + for object2_num in range(1, num_objects): + object2_rect = Rect(objects[object2_num]["bbox"]) + object2_area = object2_rect.get_area() + for object1_num in range(object2_num): + if not suppression[object1_num]: + object1_rect = Rect(objects[object1_num]["bbox"]) + object1_area = object1_rect.get_area() + intersect_area = object1_rect.intersect(object2_rect).get_area() + try: + if match_criteria == "object1_overlap": + metric = intersect_area / object1_area + elif match_criteria == "object2_overlap": + metric = intersect_area / object2_area + elif match_criteria == "iou": + metric = intersect_area / (object1_area + object2_area - intersect_area) + if metric >= match_threshold: + suppression[object2_num] = True + break + except ZeroDivisionError: + # Intended to recover from divide-by-zero + pass + + return [obj for idx, obj in enumerate(objects) if not suppression[idx]] + + +def align_supercells(supercells, rows, columns): + """ + For each supercell, align it to the rows it intersects 50% of the height of, + and the columns it intersects 50% of the width of. + Eliminate supercells for which there are no rows and columns it intersects 50% with. + """ + aligned_supercells = [] + + for supercell in supercells: + supercell["header"] = False + row_bbox_rect = None + col_bbox_rect = None + intersecting_header_rows = set() + intersecting_data_rows = set() + for row_num, row in enumerate(rows): + row_height = row["bbox"][3] - row["bbox"][1] + supercell_height = supercell["bbox"][3] - supercell["bbox"][1] + min_row_overlap = max(row["bbox"][1], supercell["bbox"][1]) + max_row_overlap = min(row["bbox"][3], supercell["bbox"][3]) + overlap_height = max_row_overlap - min_row_overlap + if "span" in supercell: + overlap_fraction = max( + overlap_height / row_height, overlap_height / supercell_height + ) + else: + overlap_fraction = overlap_height / row_height + if overlap_fraction >= 0.5: + if "header" in row and row["header"]: + intersecting_header_rows.add(row_num) + else: + intersecting_data_rows.add(row_num) + + # Supercell cannot span across the header boundary; eliminate whichever + # group of rows is the smallest + supercell["header"] = False + if len(intersecting_data_rows) > 0 and len(intersecting_header_rows) > 0: + if len(intersecting_data_rows) > len(intersecting_header_rows): + intersecting_header_rows = set() + else: + intersecting_data_rows = set() + if len(intersecting_header_rows) > 0: + supercell["header"] = True + elif "span" in supercell: + continue # Require span supercell to be in the header + intersecting_rows = intersecting_data_rows.union(intersecting_header_rows) + # Determine vertical span of aligned supercell + for row_num in intersecting_rows: + if row_bbox_rect is None: + row_bbox_rect = Rect(rows[row_num]["bbox"]) + else: + row_bbox_rect = row_bbox_rect.include_rect(rows[row_num]["bbox"]) + if row_bbox_rect is None: + continue + + intersecting_cols = [] + for col_num, col in enumerate(columns): + col_width = col["bbox"][2] - col["bbox"][0] + supercell_width = supercell["bbox"][2] - supercell["bbox"][0] + min_col_overlap = max(col["bbox"][0], supercell["bbox"][0]) + max_col_overlap = min(col["bbox"][2], supercell["bbox"][2]) + overlap_width = max_col_overlap - min_col_overlap + if "span" in supercell: + overlap_fraction = max(overlap_width / col_width, overlap_width / supercell_width) + # Multiply by 2 effectively lowers the threshold to 0.25 + if supercell["header"]: + overlap_fraction = overlap_fraction * 2 + else: + overlap_fraction = overlap_width / col_width + if overlap_fraction >= 0.5: + intersecting_cols.append(col_num) + if col_bbox_rect is None: + col_bbox_rect = Rect(col["bbox"]) + else: + col_bbox_rect = col_bbox_rect.include_rect(col["bbox"]) + if col_bbox_rect is None: + continue + + supercell_bbox = row_bbox_rect.intersect(col_bbox_rect).get_bbox() + supercell["bbox"] = supercell_bbox + + # Only a true supercell if it joins across multiple rows or columns + if ( + len(intersecting_rows) > 0 + and len(intersecting_cols) > 0 + and (len(intersecting_rows) > 1 or len(intersecting_cols) > 1) + ): + supercell["row_numbers"] = list(intersecting_rows) + supercell["column_numbers"] = intersecting_cols + aligned_supercells.append(supercell) + + # A span supercell in the header means there must be supercells above it in the header + if "span" in supercell and supercell["header"] and len(supercell["column_numbers"]) > 1: + for row_num in range(0, min(supercell["row_numbers"])): + new_supercell = { + "row_numbers": [row_num], + "column_numbers": supercell["column_numbers"], + "score": supercell["score"], + "propagated": True, + } + new_supercell_columns = [columns[idx] for idx in supercell["column_numbers"]] + new_supercell_rows = [rows[idx] for idx in supercell["row_numbers"]] + bbox = [ + min([column["bbox"][0] for column in new_supercell_columns]), + min([row["bbox"][1] for row in new_supercell_rows]), + max([column["bbox"][2] for column in new_supercell_columns]), + max([row["bbox"][3] for row in new_supercell_rows]), + ] + new_supercell["bbox"] = bbox + aligned_supercells.append(new_supercell) + + return aligned_supercells + + +def nms_supercells(supercells): + """ + A NMS scheme for supercells that first attempts to shrink supercells to + resolve overlap. + If two supercells overlap the same (sub)cell, shrink the lower confidence + supercell to resolve the overlap. If shrunk supercell is empty, remove it. + """ + + supercells = sort_objects_by_score(supercells) + num_supercells = len(supercells) + suppression = [False for supercell in supercells] + + for supercell2_num in range(1, num_supercells): + supercell2 = supercells[supercell2_num] + for supercell1_num in range(supercell2_num): + supercell1 = supercells[supercell1_num] + remove_supercell_overlap(supercell1, supercell2) + if ( + (len(supercell2["row_numbers"]) < 2 and len(supercell2["column_numbers"]) < 2) + or len(supercell2["row_numbers"]) == 0 + or len(supercell2["column_numbers"]) == 0 + ): + suppression[supercell2_num] = True + + return [obj for idx, obj in enumerate(supercells) if not suppression[idx]] + + +def header_supercell_tree(supercells): + """ + Make sure no supercell in the header is below more than one supercell in any row above it. + The cells in the header form a tree, but a supercell with more than one supercell in a row + above it means that some cell has more than one parent, which is not allowed. Eliminate + any supercell that would cause this to be violated. + """ + header_supercells = [ + supercell for supercell in supercells if "header" in supercell and supercell["header"] + ] + header_supercells = sort_objects_by_score(header_supercells) + + for header_supercell in header_supercells[:]: + ancestors_by_row = defaultdict(int) + min_row = min(header_supercell["row_numbers"]) + for header_supercell2 in header_supercells: + max_row2 = max(header_supercell2["row_numbers"]) + if max_row2 < min_row: + if set(header_supercell["column_numbers"]).issubset( + set(header_supercell2["column_numbers"]) + ): + for row2 in header_supercell2["row_numbers"]: + ancestors_by_row[row2] += 1 + for row in range(0, min_row): + if not ancestors_by_row[row] == 1: + supercells.remove(header_supercell) + break + + +def remove_supercell_overlap(supercell1, supercell2): + """ + This function resolves overlap between supercells (supercells must be + disjoint) by iteratively shrinking supercells by the fewest grid cells + necessary to resolve the overlap. + Example: + If two supercells overlap at grid cell (R, C), and supercell #1 is less + confident than supercell #2, we eliminate either row R from supercell #1 + or column C from supercell #1 by comparing the number of columns in row R + versus the number of rows in column C. If the number of columns in row R + is less than the number of rows in column C, we eliminate row R from + supercell #1. This resolves the overlap by removing fewer grid cells from + supercell #1 than if we eliminated column C from it. + """ + common_rows = set(supercell1["row_numbers"]).intersection(set(supercell2["row_numbers"])) + common_columns = set(supercell1["column_numbers"]).intersection( + set(supercell2["column_numbers"]) + ) + + # While the supercells have overlapping grid cells, continue shrinking the less-confident + # supercell one row or one column at a time + while len(common_rows) > 0 and len(common_columns) > 0: + # Try to shrink the supercell as little as possible to remove the overlap; + # if the supercell has fewer rows than columns, remove an overlapping column, + # because this removes fewer grid cells from the supercell; + # otherwise remove an overlapping row + if len(supercell2["row_numbers"]) < len(supercell2["column_numbers"]): + min_column = min(supercell2["column_numbers"]) + max_column = max(supercell2["column_numbers"]) + if max_column in common_columns: + common_columns.remove(max_column) + supercell2["column_numbers"].remove(max_column) + elif min_column in common_columns: + common_columns.remove(min_column) + supercell2["column_numbers"].remove(min_column) + else: + supercell2["column_numbers"] = [] + common_columns = set() + else: + min_row = min(supercell2["row_numbers"]) + max_row = max(supercell2["row_numbers"]) + if max_row in common_rows: + common_rows.remove(max_row) + supercell2["row_numbers"].remove(max_row) + elif min_row in common_rows: + common_rows.remove(min_row) + supercell2["row_numbers"].remove(min_row) + else: + supercell2["row_numbers"] = [] + common_rows = set() diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py new file mode 100644 index 00000000..c90c359d --- /dev/null +++ b/unstructured_inference/models/tables.py @@ -0,0 +1,623 @@ +# https://github.com/microsoft/table-transformer/blob/main/src/inference.py +# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb +import torch +import logging + +from unstructured_inference.models.unstructuredmodel import UnstructuredModel +from unstructured_inference.logger import logger + +from collections import defaultdict +import xml.etree.ElementTree as ET + +import cv2 +import numpy as np +import pandas as pd + +import pytesseract + +from transformers import TableTransformerForObjectDetection +from transformers import DetrImageProcessor +from PIL import Image +from typing import Union, Optional +from pathlib import Path +import platform + +from . import table_postprocess as postprocess +from unstructured_inference.models.table_postprocess import Rect + + +class UnstructuredTableTransformerModel(UnstructuredModel): + """Unstructured model wrapper for table-transformer.""" + + def __init__(self): + pass + + def predict(self, x: Image): + """Predict table structure deferring to run_prediction""" + super().predict(x) + return self.run_prediction(x) + + def initialize( + self, + model: Union[str, Path, TableTransformerForObjectDetection] = None, + device: Optional[str] = "cuda" if torch.cuda.is_available() else "cpu", + ): + """Loads the donut model using the specified parameters""" + self.device = device + self.feature_extractor = DetrImageProcessor() + + try: + logging.info("Loading the table structure model ...") + self.model = TableTransformerForObjectDetection.from_pretrained(model) + self.model.eval() + + except EnvironmentError: + logging.critical("Failed to initialize the model.") + logging.critical("Ensure that the model is correct") + raise ImportError( + "Review the parameters to initialize a UnstructuredTableTransformerModel obj" + ) + self.model.to(device) + + def run_prediction(self, x: Image): + """Predict table structure""" + with torch.no_grad(): + encoding = self.feature_extractor(x, return_tensors="pt").to(self.device) + outputs_structure = self.model(**encoding) + + if platform.machine() == "x86_64": + from unstructured_inference.models import paddle_ocr + + paddle_result = paddle_ocr.load_agent().ocr(np.array(x), cls=True) + + tokens = [] + for idx in range(len(paddle_result)): + res = paddle_result[idx] + for line in res: + xmin = min([i[0] for i in line[0]]) + ymin = min([i[1] for i in line[0]]) + xmax = max([i[0] for i in line[0]]) + ymax = max([i[1] for i in line[0]]) + tokens.append({"bbox": [xmin, ymin, xmax, ymax], "text": line[1][0]}) + else: + zoom = 6 + img = cv2.resize( + cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR), + None, + fx=zoom, + fy=zoom, + interpolation=cv2.INTER_CUBIC, + ) + + kernel = np.ones((1, 1), np.uint8) + img = cv2.dilate(img, kernel, iterations=1) + img = cv2.erode(img, kernel, iterations=1) + + ocr_df: pd.DataFrame = pytesseract.image_to_data( + Image.fromarray(img), output_type="data.frame" + ) + + ocr_df = ocr_df.dropna() + + tokens = [] + for idtx in ocr_df.itertuples(): + tokens.append( + { + "bbox": [ + idtx.left / zoom, + idtx.top / zoom, + (idtx.left + idtx.width) / zoom, + (idtx.top + idtx.height) / zoom, + ], + "text": idtx.text, + } + ) + + sorted(tokens, key=lambda x: x["bbox"][1] * 10000 + x["bbox"][0]) + + # 'tokens' is a list of tokens + # Need to be in a relative reading order + # If no order is provided, use current order + for idx, token in enumerate(tokens): + if "span_num" not in token: + token["span_num"] = idx + if "line_num" not in token: + token["line_num"] = 0 + if "block_num" not in token: + token["block_num"] = 0 + + html = recognize(outputs_structure, x, tokens=tokens, out_html=True)["html"] + prediction = html[0] if html else "" + return prediction + + +tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel() + + +def load_agent(): + """Loads the Tesseract OCR agent as a global variable to ensure that we only load it once.""" + global tables_agent + + if not hasattr(tables_agent, "model"): + logger.info("Loading the Tesseract OCR agent ...") + tables_agent.initialize("microsoft/table-transformer-structure-recognition") + + return + + +def get_class_map(data_type: str): + """Defines class map dictionaries""" + if data_type == "structure": + class_map = { + "table": 0, + "table column": 1, + "table row": 2, + "table column header": 3, + "table projected row header": 4, + "table spanning cell": 5, + "no object": 6, + } + elif data_type == "detection": + class_map = {"table": 0, "table rotated": 1, "no object": 2} + return class_map + + +structure_class_thresholds = { + "table": 0.5, + "table column": 0.5, + "table row": 0.5, + "table column header": 0.5, + "table projected row header": 0.5, + "table spanning cell": 0.5, + "no object": 10, +} + + +def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False): + """Recognize table elements.""" + out_formats = {} + + str_class_name2idx = get_class_map("structure") + str_class_idx2name = {v: k for k, v in str_class_name2idx.items()} + str_class_thresholds = structure_class_thresholds + + # Post-process detected objects, assign class labels + objects = outputs_to_objects(outputs, img.size, str_class_idx2name) + + # Further process the detected objects so they correspond to a consistent table + tables_structure = objects_to_structures(objects, tokens, str_class_thresholds) + # Enumerate all table cells: grid cells and spanning cells + tables_cells = [structure_to_cells(structure, tokens)[0] for structure in tables_structure] + + # Convert cells to HTML + if out_html: + tables_htmls = [cells_to_html(cells) for cells in tables_cells] + out_formats["html"] = tables_htmls + + return out_formats + + +def outputs_to_objects(outputs, img_size, class_idx2name): + """Output table element types.""" + m = outputs["logits"].softmax(-1).max(-1) + pred_labels = list(m.indices.detach().cpu().numpy())[0] + pred_scores = list(m.values.detach().cpu().numpy())[0] + pred_bboxes = outputs["pred_boxes"].detach().cpu()[0] + pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)] + + objects = [] + for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes): + class_label = class_idx2name[int(label)] + if not class_label == "no object": + objects.append( + { + "label": class_label, + "score": float(score), + "bbox": [float(elem) for elem in bbox], + } + ) + + return objects + + +# for output bounding box post-processing +def box_cxcywh_to_xyxy(x): + """Convert rectangle format from center-x, center-y, width, height to + x-min, y-min, x-max, y-max.""" + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=1) + + +def rescale_bboxes(out_bbox, size): + """Rescale relative bounding box to box of size given by size.""" + img_w, img_h = size + b = box_cxcywh_to_xyxy(out_bbox) + b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) + return b + + +def iob(bbox1, bbox2): + """ + Compute the intersection area over box area, for bbox1. + """ + intersection = Rect(bbox1).intersect(Rect(bbox2)) + + bbox1_area = Rect(bbox1).get_area() + if bbox1_area > 0: + return intersection.get_area() / bbox1_area + + return 0 + + +def objects_to_structures(objects, tokens, class_thresholds): + """ + Process the bounding boxes produced by the table structure recognition model into + a *consistent* set of table structures (rows, columns, spanning cells, headers). + This entails resolving conflicts/overlaps, and ensuring the boxes meet certain alignment + conditions (for example: rows should all have the same width, etc.). + """ + + tables = [obj for obj in objects if obj["label"] == "table"] + table_structures = [] + + for table in tables: + table_objects = [obj for obj in objects if iob(obj["bbox"], table["bbox"]) >= 0.5] + table_tokens = [token for token in tokens if iob(token["bbox"], table["bbox"]) >= 0.5] + + structure = {} + + columns = [obj for obj in table_objects if obj["label"] == "table column"] + rows = [obj for obj in table_objects if obj["label"] == "table row"] + column_headers = [obj for obj in table_objects if obj["label"] == "table column header"] + spanning_cells = [obj for obj in table_objects if obj["label"] == "table spanning cell"] + for obj in spanning_cells: + obj["projected row header"] = False + projected_row_headers = [ + obj for obj in table_objects if obj["label"] == "table projected row header" + ] + for obj in projected_row_headers: + obj["projected row header"] = True + spanning_cells += projected_row_headers + for obj in rows: + obj["column header"] = False + for header_obj in column_headers: + if iob(obj["bbox"], header_obj["bbox"]) >= 0.5: + obj["column header"] = True + + # Refine table structures + rows = postprocess.refine_rows(rows, table_tokens, class_thresholds["table row"]) + columns = postprocess.refine_columns( + columns, table_tokens, class_thresholds["table column"] + ) + + # Shrink table bbox to just the total height of the rows + # and the total width of the columns + row_rect = Rect() + for obj in rows: + row_rect.include_rect(obj["bbox"]) + column_rect = Rect() + for obj in columns: + column_rect.include_rect(obj["bbox"]) + table["row_column_bbox"] = [ + column_rect.x_min, + row_rect.y_min, + column_rect.x_max, + row_rect.y_max, + ] + table["bbox"] = table["row_column_bbox"] + + # Process the rows and columns into a complete segmented table + columns = postprocess.align_columns(columns, table["row_column_bbox"]) + rows = postprocess.align_rows(rows, table["row_column_bbox"]) + + structure["rows"] = rows + structure["columns"] = columns + structure["column headers"] = column_headers + structure["spanning cells"] = spanning_cells + + if len(rows) > 0 and len(columns) > 1: + structure = refine_table_structure(structure, class_thresholds) + + table_structures.append(structure) + + return table_structures + + +def refine_table_structure(table_structure, class_thresholds): + """ + Apply operations to the detected table structure objects such as + thresholding, NMS, and alignment. + """ + rows = table_structure["rows"] + columns = table_structure["columns"] + + # Process the headers + column_headers = table_structure["column headers"] + column_headers = postprocess.apply_threshold( + column_headers, class_thresholds["table column header"] + ) + column_headers = postprocess.nms(column_headers) + column_headers = align_headers(column_headers, rows) + + # Process spanning cells + spanning_cells = [ + elem for elem in table_structure["spanning cells"] if not elem["projected row header"] + ] + projected_row_headers = [ + elem for elem in table_structure["spanning cells"] if elem["projected row header"] + ] + spanning_cells = postprocess.apply_threshold( + spanning_cells, class_thresholds["table spanning cell"] + ) + projected_row_headers = postprocess.apply_threshold( + projected_row_headers, class_thresholds["table projected row header"] + ) + spanning_cells += projected_row_headers + # Align before NMS for spanning cells because alignment brings them into agreement + # with rows and columns first; if spanning cells still overlap after this operation, + # the threshold for NMS can basically be lowered to just above 0 + spanning_cells = postprocess.align_supercells(spanning_cells, rows, columns) + spanning_cells = postprocess.nms_supercells(spanning_cells) + + postprocess.header_supercell_tree(spanning_cells) + + table_structure["columns"] = columns + table_structure["rows"] = rows + table_structure["spanning cells"] = spanning_cells + table_structure["column headers"] = column_headers + + return table_structure + + +def align_headers(headers, rows): + """ + Adjust the header boundary to be the convex hull of the rows it intersects + at least 50% of the height of. + + For now, we are not supporting tables with multiple headers, so we need to + eliminate anything besides the top-most header. + """ + + aligned_headers = [] + + for row in rows: + row["column header"] = False + + header_row_nums = [] + for header in headers: + for row_num, row in enumerate(rows): + row_height = row["bbox"][3] - row["bbox"][1] + min_row_overlap = max(row["bbox"][1], header["bbox"][1]) + max_row_overlap = min(row["bbox"][3], header["bbox"][3]) + overlap_height = max_row_overlap - min_row_overlap + if overlap_height / row_height >= 0.5: + header_row_nums.append(row_num) + + if len(header_row_nums) == 0: + return aligned_headers + + header_rect = Rect() + if header_row_nums[0] > 0: + header_row_nums = list(range(header_row_nums[0] + 1)) + header_row_nums + + last_row_num = -1 + for row_num in header_row_nums: + if row_num == last_row_num + 1: + row = rows[row_num] + row["column header"] = True + header_rect = header_rect.include_rect(row["bbox"]) + last_row_num = row_num + else: + # Break as soon as a non-header row is encountered. + # This ignores any subsequent rows in the table labeled as a header. + # Having more than 1 header is not supported currently. + break + + header = {"bbox": header_rect.get_bbox()} + aligned_headers.append(header) + + return aligned_headers + + +def structure_to_cells(table_structure, tokens): + """ + Assuming the row, column, spanning cell, and header bounding boxes have + been refined into a set of consistent table structures, process these + table structures into table cells. This is a universal representation + format for the table, which can later be exported to Pandas or CSV formats. + Classify the cells as header/access cells or data cells + based on if they intersect with the header bounding box. + """ + columns = table_structure["columns"] + rows = table_structure["rows"] + spanning_cells = table_structure["spanning cells"] + cells = [] + subcells = [] + # Identify complete cells and subcells + for column_num, column in enumerate(columns): + for row_num, row in enumerate(rows): + column_rect = Rect(list(column["bbox"])) + row_rect = Rect(list(row["bbox"])) + cell_rect = row_rect.intersect(column_rect) + header = "column header" in row and row["column header"] + cell = { + "bbox": cell_rect.get_bbox(), + "column_nums": [column_num], + "row_nums": [row_num], + "column header": header, + } + + cell["subcell"] = False + for spanning_cell in spanning_cells: + spanning_cell_rect = Rect(list(spanning_cell["bbox"])) + if ( + spanning_cell_rect.intersect(cell_rect).get_area() / cell_rect.get_area() + ) > 0.5: + cell["subcell"] = True + break + + if cell["subcell"]: + subcells.append(cell) + else: + # cell text = extract_text_inside_bbox(table_spans, cell['bbox']) + # cell['cell text'] = cell text + cell["projected row header"] = False + cells.append(cell) + + for spanning_cell in spanning_cells: + spanning_cell_rect = Rect(list(spanning_cell["bbox"])) + cell_columns = set() + cell_rows = set() + cell_rect = None + header = True + for subcell in subcells: + subcell_rect = Rect(list(subcell["bbox"])) + subcell_rect_area = subcell_rect.get_area() + if (subcell_rect.intersect(spanning_cell_rect).get_area() / subcell_rect_area) > 0.5: + if cell_rect is None: + cell_rect = Rect(list(subcell["bbox"])) + else: + cell_rect.include_rect(list(subcell["bbox"])) + cell_rows = cell_rows.union(set(subcell["row_nums"])) + cell_columns = cell_columns.union(set(subcell["column_nums"])) + # By convention here, all subcells must be classified + # as header cells for a spanning cell to be classified as a header cell; + # otherwise, this could lead to a non-rectangular header region + header = header and "column header" in subcell and subcell["column header"] + if len(cell_rows) > 0 and len(cell_columns) > 0: + cell = { + "bbox": cell_rect.get_bbox(), + "column_nums": list(cell_columns), + "row_nums": list(cell_rows), + "column header": header, + "projected row header": spanning_cell["projected row header"], + } + cells.append(cell) + + # Compute a confidence score based on how well the page tokens + # slot into the cells reported by the model + _, _, cell_match_scores = postprocess.slot_into_containers(cells, tokens) + try: + mean_match_score = sum(cell_match_scores) / len(cell_match_scores) + min_match_score = min(cell_match_scores) + confidence_score = (mean_match_score + min_match_score) / 2 + except ZeroDivisionError: + confidence_score = 0 + + # Dilate rows and columns before final extraction + # dilated_columns = fill_column_gaps(columns, table_bbox) + dilated_columns = columns + # dilated_rows = fill_row_gaps(rows, table_bbox) + dilated_rows = rows + for cell in cells: + column_rect = Rect() + for column_num in cell["column_nums"]: + column_rect.include_rect(list(dilated_columns[column_num]["bbox"])) + row_rect = Rect() + for row_num in cell["row_nums"]: + row_rect.include_rect(list(dilated_rows[row_num]["bbox"])) + cell_rect = column_rect.intersect(row_rect) + cell["bbox"] = cell_rect.get_bbox() + + span_nums_by_cell, _, _ = postprocess.slot_into_containers( + cells, tokens, overlap_threshold=0.001, forced_assignment=False + ) + + for cell, cell_span_nums in zip(cells, span_nums_by_cell): + cell_spans = [tokens[num] for num in cell_span_nums] + # TODO: Refine how text is extracted; should be character-based, not span-based; + # but need to associate + cell["cell text"] = postprocess.extract_text_from_spans( + cell_spans, remove_integer_superscripts=False + ) + cell["spans"] = cell_spans + + # Adjust the row, column, and cell bounding boxes to reflect the extracted text + num_rows = len(rows) + rows = postprocess.sort_objects_top_to_bottom(rows) + num_columns = len(columns) + columns = postprocess.sort_objects_left_to_right(columns) + min_y_values_by_row = defaultdict(list) + max_y_values_by_row = defaultdict(list) + min_x_values_by_column = defaultdict(list) + max_x_values_by_column = defaultdict(list) + for cell in cells: + min_row = min(cell["row_nums"]) + max_row = max(cell["row_nums"]) + min_column = min(cell["column_nums"]) + max_column = max(cell["column_nums"]) + for span in cell["spans"]: + min_x_values_by_column[min_column].append(span["bbox"][0]) + min_y_values_by_row[min_row].append(span["bbox"][1]) + max_x_values_by_column[max_column].append(span["bbox"][2]) + max_y_values_by_row[max_row].append(span["bbox"][3]) + for row_num, row in enumerate(rows): + if len(min_x_values_by_column[0]) > 0: + row["bbox"][0] = min(min_x_values_by_column[0]) + if len(min_y_values_by_row[row_num]) > 0: + row["bbox"][1] = min(min_y_values_by_row[row_num]) + if len(max_x_values_by_column[num_columns - 1]) > 0: + row["bbox"][2] = max(max_x_values_by_column[num_columns - 1]) + if len(max_y_values_by_row[row_num]) > 0: + row["bbox"][3] = max(max_y_values_by_row[row_num]) + for column_num, column in enumerate(columns): + if len(min_x_values_by_column[column_num]) > 0: + column["bbox"][0] = min(min_x_values_by_column[column_num]) + if len(min_y_values_by_row[0]) > 0: + column["bbox"][1] = min(min_y_values_by_row[0]) + if len(max_x_values_by_column[column_num]) > 0: + column["bbox"][2] = max(max_x_values_by_column[column_num]) + if len(max_y_values_by_row[num_rows - 1]) > 0: + column["bbox"][3] = max(max_y_values_by_row[num_rows - 1]) + for cell in cells: + row_rect = None + column_rect = None + for row_num in cell["row_nums"]: + if row_rect is None: + row_rect = Rect(list(rows[row_num]["bbox"])) + else: + row_rect.include_rect(list(rows[row_num]["bbox"])) + for column_num in cell["column_nums"]: + if column_rect is None: + column_rect = Rect(list(columns[column_num]["bbox"])) + else: + column_rect.include_rect(list(columns[column_num]["bbox"])) + cell_rect = row_rect.intersect(column_rect) + if cell_rect.get_area() > 0: + cell["bbox"] = cell_rect.get_bbox() + pass + + return cells, confidence_score + + +def cells_to_html(cells): + """Convert table structure to html format.""" + cells = sorted(cells, key=lambda k: min(k["column_nums"])) + cells = sorted(cells, key=lambda k: min(k["row_nums"])) + + table = ET.Element("table") + current_row = -1 + + for cell in cells: + this_row = min(cell["row_nums"]) + + attrib = {} + colspan = len(cell["column_nums"]) + if colspan > 1: + attrib["colspan"] = str(colspan) + rowspan = len(cell["row_nums"]) + if rowspan > 1: + attrib["rowspan"] = str(rowspan) + if this_row > current_row: + current_row = this_row + if cell["column header"]: + cell_tag = "th" + row = ET.SubElement(table, "thead") + else: + cell_tag = "td" + row = ET.SubElement(table, "tr") + tcell = ET.SubElement(row, cell_tag, attrib=attrib) + tcell.text = cell["cell text"] + + return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))