diff --git a/CHANGELOG.md b/CHANGELOG.md
index d1add3a2..f6e59d30 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,6 @@
-## 0.2.13-dev0
+## 0.2.13
+* Add table processing
* Change OCR logic to be aware of PDF image elements
## 0.2.12
diff --git a/Makefile b/Makefile
index 09c498f5..3ac85a38 100644
--- a/Makefile
+++ b/Makefile
@@ -20,7 +20,7 @@ install-base: install-base-pip-packages
install: install-base-pip-packages install-dev install-detectron2 install-test
.PHONY: install-ci
-install-ci: install-base-pip-packages install-test
+install-ci: install-base-pip-packages install-test install-paddleocr
.PHONY: install-base-pip-packages
install-base-pip-packages:
@@ -31,6 +31,10 @@ install-base-pip-packages:
install-detectron2:
pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@78d5b4f335005091fe0364ce4775d711ec93566e"
+.PHONY: install-paddleocr
+install-paddleocr:
+ pip install "unstructured.PaddleOCR"
+
.PHONY: install-test
install-test:
pip install -r requirements/test.txt
diff --git a/README.md b/README.md
index fb4606b6..571b9588 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,17 @@ Windows is not officially supported by Detectron2, but some users are able to in
See discussion [here](https://layout-parser.github.io/tutorials/installation#for-windows-users) for
tips on installing Detectron2 on Windows.
+### PaddleOCR
+
+[PaddleOCR](https://github.com/Unstructured-IO/unstructured.PaddleOCR) is required for table processing on `x86_64` architectures.
+It should not be installed on macOS with an Apple Silicon CPU.
+
+To install PaddleOCR, run the following command:
+
+```shell
+pip install "unstructured.PaddleOCR"
+```
+
### Repository
To install the repository for development, clone the repo and run `make install` to install dependencies.
diff --git a/sample-docs/example_table.jpg b/sample-docs/example_table.jpg
new file mode 100644
index 00000000..c17ce7bf
Binary files /dev/null and b/sample-docs/example_table.jpg differ
diff --git a/setup.cfg b/setup.cfg
index 78525135..dd54ef0d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,7 +3,7 @@ license_files = LICENSE.md
[flake8]
max-line-length = 100
-ignore = D100, D101, D104, D105, D107, D2, D4
+extend-ignore = D100, D101, D104, D105, D107, D2, D4
per-file-ignores =
test_*/**: D
diff --git a/setup.py b/setup.py
index bd19583f..9a32d82f 100644
--- a/setup.py
+++ b/setup.py
@@ -18,6 +18,7 @@
limitations under the License.
"""
from setuptools import setup, find_packages
+from platform import machine
from unstructured_inference.__version__ import __version__
@@ -60,5 +61,5 @@
"onnxruntime",
"transformers",
],
- extras_require={},
+ extras_require={"paddle-ocr": "unstructured.PaddleOCR"},
)
diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py
index 6e322a39..4b4b4d7f 100644
--- a/test_unstructured_inference/inference/test_layout.py
+++ b/test_unstructured_inference/inference/test_layout.py
@@ -186,11 +186,12 @@ def points(self):
class MockPageLayout(layout.PageLayout):
- def __init__(self, layout=None, model=None, ocr_strategy="auto"):
+ def __init__(self, layout=None, model=None, ocr_strategy="auto", extract_tables=False):
self.image = None
self.layout = layout
self.model = model
self.ocr_strategy = ocr_strategy
+ self.extract_tables = extract_tables
def ocr(self, text_block: MockTextRegion):
return text_block.ocr_text
diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py
new file mode 100644
index 00000000..68e29cb5
--- /dev/null
+++ b/test_unstructured_inference/models/test_tables.py
@@ -0,0 +1,569 @@
+import pytest
+from unittest.mock import patch
+
+from transformers.models.table_transformer.modeling_table_transformer import TableTransformerDecoder
+
+import unstructured_inference.models.tables as tables
+import unstructured_inference.models.table_postprocess as postprocess
+
+
+@pytest.mark.parametrize(
+ "model_path",
+ [
+ ("invalid_table_path"),
+ ("incorrect_table_path"),
+ ],
+)
+def test_load_table_model_raises_when_not_available(model_path):
+ with pytest.raises(ImportError):
+ table_model = tables.UnstructuredTableTransformerModel()
+ table_model.initialize(model=model_path)
+
+
+@pytest.mark.parametrize(
+ "model_path",
+ [
+ "microsoft/table-transformer-structure-recognition",
+ ],
+)
+def test_load_donut_model(model_path):
+ table_model = tables.UnstructuredTableTransformerModel()
+ table_model.initialize(model=model_path)
+ assert type(table_model.model.model.decoder) == TableTransformerDecoder
+
+
+@pytest.fixture
+def sample_table_transcript(platform_type):
+ if platform_type == "x86_64":
+ out = (
+ '
| About these Coverage Examples: | '
+ ' | This is not a cost estimator. Treatments shown are just examples '
+ "of how this plan might cover medical care. Your actual costs will be different "
+ "depending on the actual care you receive, the prices your providers charge, and many "
+ "other factors. Focus on the cost sharing amounts (deductibles, copayments and "
+ "coinsurance) and excluded services under the plan. Use this information to compare "
+ "the portion of costs you might pay under different health plans. Please note these "
+ 'coverage examples are based on self-only coverage | '
+ "Peg is Having a Baby (9 months of in-network pre-natal care and a hospital delivery)"
+ "th> | Managing Joe's type 2 Diabetes (a year of routine in-network care of a well- "
+ 'controlled condition) | | Mia\'s Simple Fracture (in-network '
+ 'emergency room visit and follow up care) | | The plan\'s '
+ "overall deductible $750 Specialist copayment $50 Hospital (facility) coinsurance 10"
+ "% Other coinsurance 10% | The plan's overall deductible Specialist copayment "
+ r"Hospital (facility) coinsurance Other coinsurance | $750 $50 10% 10% | "
+ "The plan's overall deductible Specialist copayment Hospital (facility) coinsurance "
+ r'Other coinsurance | $750 $50 10% 10% |
| '
+ "This EXAMPLE event includes services like: Specialist office visits (prenatal care) "
+ "Childbirth/Delivery Professional Services Childbirth/Delivery Facility Services | <"
+ 'td colspan="2" rowspan="2">This EXAMPLE event includes services like: Primary care '
+ "physician office visits (including disease education) Diagnostic tests (blood work) "
+ 'Prescription drugs Durable medical equipment (glucose meter)This EXAMPLE event includes services like: Emergency room care (including '
+ "medical Diagnostic test (x-ray) Durable medical equipment (crutches) Rehabilitation "
+ 'services (physical therapy) |
| Diagnostic tests ('
+ "ultrasounds and blood work) Specialist visit (anesthesia) |
| Total "
+ "Example Cost | $12,700 | Total Example Cost | $5,600 | Total "
+ 'Example Cost | $2,800 |
| In this example, Peg would '
+ 'pay: | In this example, Joe would pay: | In this example, Mia '
+ 'would pay: |
| Cost Sharing | | Cost Sharing | Cost Sharing | |
| Deductibles | $750 | "
+ "Deductibles | $120 | Deductibles | $750 |
| Copayments"
+ ' | $30 | Copayments | $700 | '
+ 'Copayments $400 Coinsurance $30 |
| Coinsurance $'
+ "1,200 What isn't covered | Coinsurance | $0 |
| What isn't "
+ "covered | | What isn't covered | |
| Limits or "
+ "exclusions | $20 | Limits or exclusions | $20 | Limits or "
+ "exclusions | $0 |
| The total Peg would pay is | $2,000 | <"
+ 'td>The total Joe would pay is$840 | The total Mia would '
+ "pay is $1,180 |
| Plan Name: NVIDIA PPO PlanPIan ID: 14603022 | "
+ "td> | The plan would be responsible for the other costs of these EXAMPLE covered "
+ "services | | | Page 8 of 8 |
"
+ )
+ else:
+ out = (
+ '| About these Coverage Examples: | '
+ "This is not a cost depending on the (deductibles, pay under different | estimator. |reatments shown are just examples of how this plan might '
+ "cover medical care. Your actual costs will be different actual care you receive, the "
+ "prices your providers charge, and many other factors. Focus on the cost sharing "
+ "amounts copayments and coinsurance) and excluded services under the plan. Use this "
+ "information to compare the portion of costs you might health plans. Please note these "
+ 'coverage examples are based on self-only coverage. | '
+ "Peg is Having a Baby (9 months of in-network pre-natal care and a hospital delivery)"
+ "th> | Managing Joe's type 2 (a year of routine in-network care controlled conaition"
+ ') | Diabetes of a well- | Mia\'s Simple Fracture (in-network '
+ 'emergency room visit and follow up care) | | = The plan'
+ "'s overall deductible $750 = Specialist copayment $50 = Hospital (facility) "
+ "coinsurance 10% = Other coinsurance 10% | = The plan's overall deductible = "
+ "Specialist copayment = Hospital (facility) coinsurance = Other coinsurance | $"
+ r"750 $50 10% 10% | = The plan's overall deductible = Specialist copayment = "
+ r"Hospital (facility) coinsurance = Other coinsurance | $750 $50 10% 10% |
| This EXAMPLE event includes services like: '
+ "specialist office visits (prenatal care) Childbirth/Delivery Professional Services "
+ 'Childbirth/Delivery Facility Services | This EXAMPLE '
+ "event includes services like: Primary care physician office visits (including aisease "
+ "education) Diagnostic tests (b/ood work) Prescription drugs Durable medical equipment "
+ '(/g/ucose meter) | This EXAMPLE event includes services '
+ "like: Emergency room care (including meaical suoplies) Diagnostic test (x-ray) "
+ "Durable medical equipment (crutches) Rehabilitation services (o/hysical therapy) |
| Diagnostic tests (u/trasounas and blood work) specialist '
+ "visit (anesthesia) |
| Total Example Cost | | $12,700 | "
+ "Total Example Cost | | $5,600 | Total Example Cost | | $2,800 | "
+ 'tr>
| In this example, Peg would pay: | In this example, Joe '
+ 'would pay: | In this example, Mia would pay: |
| Cost '
+ 'Sharing | | Cost Sharing | Cost Sharing | | '
+ "tr>
| Deductibles | $/50 | Deductibles | $120 | "
+ "Deductibles | $/50 |
| Copayments | $30 | Copayments"
+ 'td> | $/00 | Copayments $400 Coinsurance $30 |
<'
+ 'tr>Coinsurance $1,200 What isn t covered | '
+ "Coinsurance | | | What isnt covered | | What isnt "
+ "covered | |
| Limits or exclusions | $20 | Limits or "
+ "exclusions | | $20 | Limits or exclusions | |
| The "
+ "total Peg would pay is | $2,000 | The total Joe would pay is | 9840"
+ ' | The total Mia would pay is $1,180 |
| Plan Name: '
+ "NVIDIA PPO Plan | The plan would Plan ID: 14603022 | be responsible for "
+ "the other costs of these | EXAMPLE | covered services. | Page 8 of 8"
+ " |
"
+ )
+ return out
+
+
+@pytest.mark.parametrize(
+ "input_test, output_test",
+ [
+ (
+ [
+ {
+ "label": "table column header",
+ "score": 0.9349299073219299,
+ "bbox": [
+ 47.83147430419922,
+ 116.8877944946289,
+ 2557.79296875,
+ 216.98883056640625,
+ ],
+ },
+ {
+ "label": "table column header",
+ "score": 0.934,
+ "bbox": [
+ 47.83147430419922,
+ 116.8877944946289,
+ 2557.79296875,
+ 216.98883056640625,
+ ],
+ },
+ ],
+ [
+ {
+ "label": "table column header",
+ "score": 0.9349299073219299,
+ "bbox": [
+ 47.83147430419922,
+ 116.8877944946289,
+ 2557.79296875,
+ 216.98883056640625,
+ ],
+ }
+ ],
+ ),
+ ([], []),
+ ],
+)
+def test_nms(input_test, output_test):
+ output = postprocess.nms(input_test)
+
+ assert output == output_test
+
+
+@pytest.mark.parametrize(
+ "supercell1, supercell2",
+ [
+ (
+ {
+ "label": "table spanning cell",
+ "score": 0.526617169380188,
+ "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [3, 4],
+ "column_numbers": [0, 4],
+ },
+ {
+ "label": "table spanning cell",
+ "score": 0.5199193954467773,
+ "bbox": [
+ 98.92312622070312,
+ 676.1566772460938,
+ 751.0982666015625,
+ 938.5986938476562,
+ ],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [3, 4, 6],
+ "column_numbers": [0, 4],
+ },
+ ),
+ (
+ {
+ "label": "table spanning cell",
+ "score": 0.526617169380188,
+ "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [3, 4],
+ "column_numbers": [0, 4],
+ },
+ {
+ "label": "table spanning cell",
+ "score": 0.5199193954467773,
+ "bbox": [
+ 98.92312622070312,
+ 676.1566772460938,
+ 751.0982666015625,
+ 938.5986938476562,
+ ],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [4],
+ "column_numbers": [0, 4, 6],
+ },
+ ),
+ (
+ {
+ "label": "table spanning cell",
+ "score": 0.526617169380188,
+ "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [3, 4],
+ "column_numbers": [1, 4],
+ },
+ {
+ "label": "table spanning cell",
+ "score": 0.5199193954467773,
+ "bbox": [
+ 98.92312622070312,
+ 676.1566772460938,
+ 751.0982666015625,
+ 938.5986938476562,
+ ],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [4],
+ "column_numbers": [0, 4, 6],
+ },
+ ),
+ (
+ {
+ "label": "table spanning cell",
+ "score": 0.526617169380188,
+ "bbox": [1446.2801513671875, 1023.817138671875, 2114.3525390625, 1099.20166015625],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [3, 4],
+ "column_numbers": [1, 4],
+ },
+ {
+ "label": "table spanning cell",
+ "score": 0.5199193954467773,
+ "bbox": [
+ 98.92312622070312,
+ 676.1566772460938,
+ 751.0982666015625,
+ 938.5986938476562,
+ ],
+ "projected row header": False,
+ "header": False,
+ "row_numbers": [2, 4, 5, 6, 7, 8],
+ "column_numbers": [0, 4, 6],
+ },
+ ),
+ ],
+)
+def test_remove_supercell_overlap(supercell1, supercell2):
+ assert postprocess.remove_supercell_overlap(supercell1, supercell2) is None
+
+
+@pytest.mark.parametrize(
+ ("supercells", "rows", "columns", "output_test"),
+ [
+ (
+ [
+ {
+ "label": "table spanning cell",
+ "score": 0.9,
+ "bbox": [
+ 98.92312622070312,
+ 143.11549377441406,
+ 2115.197265625,
+ 1238.27587890625,
+ ],
+ "projected row header": True,
+ "header": True,
+ "span": True,
+ },
+ ],
+ [
+ {
+ "label": "table row",
+ "score": 0.9299452900886536,
+ "bbox": [0, 0, 10, 10],
+ "column header": True,
+ "header": True,
+ },
+ {
+ "label": "table row",
+ "score": 0.9299452900886536,
+ "bbox": [
+ 98.92312622070312,
+ 143.11549377441406,
+ 2114.3525390625,
+ 193.67681884765625,
+ ],
+ "column header": True,
+ "header": True,
+ },
+ {
+ "label": "table row",
+ "score": 0.9299452900886536,
+ "bbox": [
+ 98.92312622070312,
+ 143.11549377441406,
+ 2114.3525390625,
+ 193.67681884765625,
+ ],
+ "column header": True,
+ "header": True,
+ },
+ ],
+ [
+ {
+ "label": "table column",
+ "score": 0.9996132254600525,
+ "bbox": [
+ 98.92312622070312,
+ 143.11549377441406,
+ 517.6508178710938,
+ 1616.48779296875,
+ ],
+ },
+ {
+ "label": "table column",
+ "score": 0.9935646653175354,
+ "bbox": [
+ 520.0474853515625,
+ 143.11549377441406,
+ 751.0982666015625,
+ 1616.48779296875,
+ ],
+ },
+ ],
+ [
+ {
+ "label": "table spanning cell",
+ "score": 0.9,
+ "bbox": [
+ 98.92312622070312,
+ 143.11549377441406,
+ 751.0982666015625,
+ 193.67681884765625,
+ ],
+ "projected row header": True,
+ "header": True,
+ "span": True,
+ "row_numbers": [1, 2],
+ "column_numbers": [0, 1],
+ },
+ {
+ "row_numbers": [0],
+ "column_numbers": [0, 1],
+ "score": 0.9,
+ "propagated": True,
+ "bbox": [
+ 98.92312622070312,
+ 143.11549377441406,
+ 751.0982666015625,
+ 193.67681884765625,
+ ],
+ },
+ ],
+ )
+ ],
+)
+def test_align_supercells(supercells, rows, columns, output_test):
+ assert postprocess.align_supercells(supercells, rows, columns) == output_test
+
+
+@pytest.mark.parametrize("rows, bbox, output", [([1.0], [0.0], [1.0])])
+def test_align_rows(rows, bbox, output):
+ assert postprocess.align_rows(rows, bbox) == output
+
+
+@pytest.mark.parametrize(
+ ("model_path", "platform_type"),
+ [
+ ("microsoft/table-transformer-structure-recognition", "arm64"),
+ ("microsoft/table-transformer-structure-recognition", "x86_64"),
+ ],
+)
+def test_table_prediction(model_path, sample_table_transcript, platform_type):
+ with patch("platform.machine", return_value=platform_type):
+ table_model = tables.UnstructuredTableTransformerModel()
+ from PIL import Image
+
+ table_model.initialize(model=model_path)
+ img = Image.open("./sample-docs/example_table.jpg").convert("RGB")
+ prediction = table_model.predict(img)
+ assert prediction == sample_table_transcript
+
+
+def test_intersect():
+ a = postprocess.Rect()
+ b = postprocess.Rect([1, 2, 3, 4])
+ assert a.intersect(b).get_area() == 4.0
+
+
+def test_include_rect():
+ a = postprocess.Rect()
+ assert a.include_rect([1, 2, 3, 4]).get_area() == 4.0
+
+
+@pytest.mark.parametrize(
+ ("spans", "join_with_space", "expected"),
+ [
+ (
+ [
+ {
+ "flags": 2**0,
+ "text": "5",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 0,
+ }
+ ],
+ True,
+ "",
+ ),
+ (
+ [
+ {
+ "flags": 2**0,
+ "text": "p",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 0,
+ }
+ ],
+ True,
+ "p",
+ ),
+ (
+ [
+ {
+ "flags": 2**0,
+ "text": "p",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 0,
+ },
+ {
+ "flags": 2**0,
+ "text": "p",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 0,
+ },
+ ],
+ True,
+ "p p",
+ ),
+ (
+ [
+ {
+ "flags": 2**0,
+ "text": "p",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 0,
+ },
+ {
+ "flags": 2**0,
+ "text": "p",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 1,
+ },
+ ],
+ True,
+ "p p",
+ ),
+ (
+ [
+ {
+ "flags": 2**0,
+ "text": "p",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 0,
+ },
+ {
+ "flags": 2**0,
+ "text": "p",
+ "superscript": False,
+ "span_num": 0,
+ "line_num": 0,
+ "block_num": 1,
+ },
+ ],
+ False,
+ "p p",
+ ),
+ ],
+)
+def test_extract_text_from_spans(spans, join_with_space, expected):
+ res = postprocess.extract_text_from_spans(
+ spans, join_with_space=join_with_space, remove_integer_superscripts=True
+ )
+ assert res == expected
+
+
+@pytest.mark.parametrize(
+ ("supercells", "expected_len"),
+ [
+ ([{"header": "hi", "row_numbers": [0, 1, 2], "score": 0.9}], 1),
+ (
+ [
+ {
+ "header": "hi",
+ "row_numbers": [0],
+ "column_numbers": [1, 2, 3],
+ "score": 0.9,
+ },
+ {"header": "hi", "row_numbers": [1], "column_numbers": [1], "score": 0.9},
+ {"header": "hi", "row_numbers": [1], "column_numbers": [2], "score": 0.9},
+ {"header": "hi", "row_numbers": [1], "column_numbers": [3], "score": 0.9},
+ ],
+ 4,
+ ),
+ (
+ [
+ {"header": "hi", "row_numbers": [0], "column_numbers": [0], "score": 0.9},
+ {"header": "hi", "row_numbers": [1], "column_numbers": [0], "score": 0.9},
+ {"header": "hi", "row_numbers": [1, 2], "column_numbers": [0], "score": 0.9},
+ {"header": "hi", "row_numbers": [3], "column_numbers": [0], "score": 0.9},
+ ],
+ 3,
+ ),
+ ],
+)
+def test_header_supercell_tree(supercells, expected_len):
+ postprocess.header_supercell_tree(supercells)
+ assert len(supercells) == expected_len
diff --git a/test_unstructured_inference/models/test_yolox.py b/test_unstructured_inference/models/test_yolox.py
index ab749b54..b03565d0 100644
--- a/test_unstructured_inference/models/test_yolox.py
+++ b/test_unstructured_inference/models/test_yolox.py
@@ -22,7 +22,7 @@ def test_layout_yolox_local_parsing_pdf():
filename = os.path.join("sample-docs", "loremipsum.pdf")
document_layout = process_file_with_model(filename, model_name="yolox")
content = str(document_layout)
- assert "Lorem ipsum" in content
+ assert "libero fringilla" in content
assert len(document_layout.pages) == 1
# NOTE(benjamin) The example sent to the test contains 5 detections
assert len(document_layout.pages[0].elements) == 5
@@ -57,7 +57,7 @@ def test_layout_yolox_local_parsing_pdf_soft():
filename = os.path.join("sample-docs", "loremipsum.pdf")
document_layout = process_file_with_model(filename, model_name="yolox_tiny")
content = str(document_layout)
- assert "Lorem ipsum" in content
+ assert "libero fringilla" in content
assert len(document_layout.pages) == 1
# NOTE(benjamin) Soft version of the test, run make test-long in order to run with full model
assert len(document_layout.pages[0].elements) > 0
diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py
index b42b08af..53cb2730 100644
--- a/unstructured_inference/__version__.py
+++ b/unstructured_inference/__version__.py
@@ -1 +1 @@
-__version__ = "0.2.13-dev0" # pragma: no cover
+__version__ = "0.2.13" # pragma: no cover
diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py
index cc7f1cfb..297c7cdf 100644
--- a/unstructured_inference/inference/layout.py
+++ b/unstructured_inference/inference/layout.py
@@ -14,6 +14,7 @@
from unstructured_inference.inference.elements import TextRegion, ImageTextRegion, LayoutElement
from unstructured_inference.logger import logger
import unstructured_inference.models.tesseract as tesseract
+import unstructured_inference.models.tables as tables
from unstructured_inference.models.base import get_model
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
@@ -54,6 +55,7 @@ def from_file(
model: Optional[UnstructuredModel] = None,
fixed_layouts: Optional[List[Optional[List[TextRegion]]]] = None,
ocr_strategy: str = "auto",
+ extract_tables: bool = False,
) -> DocumentLayout:
"""Creates a DocumentLayout from a pdf file."""
logger.info(f"Reading PDF for file: {filename} ...")
@@ -74,6 +76,7 @@ def from_file(
layout=layout,
ocr_strategy=ocr_strategy,
fixed_layout=fixed_layout,
+ extract_tables=extract_tables,
)
pages.append(page)
return cls.from_pages(pages)
@@ -85,6 +88,7 @@ def from_image_file(
model: Optional[UnstructuredModel] = None,
ocr_strategy: str = "auto",
fixed_layout: Optional[List[TextRegion]] = None,
+ extract_tables: bool = False,
) -> DocumentLayout:
"""Creates a DocumentLayout from an image file."""
logger.info(f"Reading image file: {filename} ...")
@@ -96,7 +100,12 @@ def from_image_file(
else:
raise FileNotFoundError(f'File "{filename}" not found!') from e
page = PageLayout.from_image(
- image, model=model, layout=None, ocr_strategy=ocr_strategy, fixed_layout=fixed_layout
+ image,
+ model=model,
+ layout=None,
+ ocr_strategy=ocr_strategy,
+ fixed_layout=fixed_layout,
+ extract_tables=extract_tables,
)
return cls.from_pages([page])
@@ -111,6 +120,7 @@ def __init__(
layout: Optional[List[TextRegion]],
model: Optional[UnstructuredModel] = None,
ocr_strategy: str = "auto",
+ extract_tables: bool = False,
):
self.image = image
self.image_array: Union[np.ndarray, None] = None
@@ -121,6 +131,7 @@ def __init__(
if ocr_strategy not in VALID_OCR_STRATEGIES:
raise ValueError(f"ocr_strategy must be one of {VALID_OCR_STRATEGIES}.")
self.ocr_strategy = ocr_strategy
+ self.extract_tables = extract_tables
def __str__(self) -> str:
return "\n\n".join([str(element) for element in self.elements])
@@ -148,7 +159,11 @@ def get_elements_from_layout(self, layout: List[TextRegion]) -> List[LayoutEleme
layout.sort(key=lambda element: element.y1)
elements = []
for e in tqdm(layout):
- elements.append(get_element_from_block(e, self.image, self.layout, self.ocr_strategy))
+ elements.append(
+ get_element_from_block(
+ e, self.image, self.layout, self.ocr_strategy, self.extract_tables
+ )
+ )
return elements
def _get_image_array(self) -> Union[np.ndarray, None]:
@@ -164,10 +179,18 @@ def from_image(
model: Optional[UnstructuredModel] = None,
layout: Optional[List[TextRegion]] = None,
ocr_strategy: str = "auto",
+ extract_tables: bool = False,
fixed_layout: Optional[List[TextRegion]] = None,
):
"""Creates a PageLayout from an already-loaded PIL Image."""
- page = cls(number=0, image=image, layout=layout, model=model, ocr_strategy=ocr_strategy)
+ page = cls(
+ number=0,
+ image=image,
+ layout=layout,
+ model=model,
+ ocr_strategy=ocr_strategy,
+ extract_tables=extract_tables,
+ )
if fixed_layout is None:
page.get_elements()
else:
@@ -239,12 +262,15 @@ def get_element_from_block(
image: Optional[Image.Image] = None,
pdf_objects: Optional[List[Union[TextRegion, ImageTextRegion]]] = None,
ocr_strategy: str = "auto",
+ extract_tables: bool = False,
) -> LayoutElement:
"""Creates a LayoutElement from a given layout or image by finding all the text that lies within
a given block."""
if block.text is not None:
# If block text is already populated, we'll assume it's correct
text = block.text
+ elif extract_tables and isinstance(block, LayoutElement) and block.type == "Table":
+ text = interprete_table_block(block, image)
elif pdf_objects is not None:
text = aggregate_by_block(block, image, pdf_objects, ocr_strategy)
elif image is not None:
@@ -284,11 +310,21 @@ def aggregate_by_block(
return text
+def interprete_table_block(text_block: TextRegion, image: Image.Image) -> str:
+ """Extract the contents of a table."""
+ tables.load_agent()
+ if tables.tables_agent is None:
+ raise RuntimeError("Unable to load table extraction agent.")
+ padded_block = text_block.pad(12)
+ cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2))
+ return tables.tables_agent.predict(cropped_image)
+
+
def ocr(text_block: TextRegion, image: Image.Image) -> str:
"""Runs a cropped text block image through and OCR agent."""
logger.debug("Running OCR on text block ...")
tesseract.load_agent()
- padded_block = text_block.pad(5)
+ padded_block = text_block.pad(12)
cropped_image = image.crop((padded_block.x1, padded_block.y1, padded_block.x2, padded_block.y2))
return tesseract.ocr_agent.detect(cropped_image)
@@ -309,7 +345,7 @@ def load_pdf(
vertical_ttb: bool = True, # Should vertical words be read top-to-bottom?
extra_attrs: Optional[List[str]] = None,
split_at_punctuation: Union[bool, str] = False,
- dpi: int = 72,
+ dpi: int = 200,
) -> Tuple[List[List[TextRegion]], List[Image.Image]]:
"""Loads the image and word objects from a pdf using pdfplumber and the image renderings of the
pdf pages using pdf2image"""
diff --git a/unstructured_inference/models/paddle_ocr.py b/unstructured_inference/models/paddle_ocr.py
new file mode 100644
index 00000000..3fee51e0
--- /dev/null
+++ b/unstructured_inference/models/paddle_ocr.py
@@ -0,0 +1,12 @@
+from unstructured_paddleocr import PaddleOCR
+
+paddle_ocr: PaddleOCR = None
+
+
+def load_agent():
+ """Loads the PaddleOCR agent into the global `paddle_ocr` variable so it can be reused."""
+
+ global paddle_ocr
+ paddle_ocr = PaddleOCR(use_angle_cls=True, lang="en", mkl_dnn=True, show_log=False)
+
+ return paddle_ocr
diff --git a/unstructured_inference/models/table_postprocess.py b/unstructured_inference/models/table_postprocess.py
new file mode 100644
index 00000000..67a59a0c
--- /dev/null
+++ b/unstructured_inference/models/table_postprocess.py
@@ -0,0 +1,623 @@
+# https://github.com/microsoft/table-transformer/blob/main/src/postprocess.py
+"""
+Copyright (C) 2021 Microsoft Corporation
+"""
+from collections import defaultdict
+
+
+class Rect:
+ """Axis-aligned rectangle described by x_min, y_min, x_max and y_max.
+
+ NOTE: intersect() and include_rect() overwrite this rectangle's coordinates in place
+ and return self; build a fresh Rect when the original values are still needed.
+ """
+
+ def __init__(self, bbox=None):
+ """Builds the rectangle from bbox, read as [x_min, y_min, x_max, y_max]; all four
+ coordinates are 0 when bbox is None"""
+ if bbox is None:
+ self.x_min = 0
+ self.y_min = 0
+ self.x_max = 0
+ self.y_max = 0
+ else:
+ self.x_min = bbox[0]
+ self.y_min = bbox[1]
+ self.x_max = bbox[2]
+ self.y_max = bbox[3]
+
+ def get_area(self):
+ """Calculates the area of the rectangle; returns 0.0 when width * height is not
+ positive"""
+ area = (self.x_max - self.x_min) * (self.y_max - self.y_min)
+ return area if area > 0 else 0.0
+
+ def intersect(self, other):
+ """Calculates the intersection with another rectangle, storing the result in this
+ rectangle and returning self"""
+ # A zero-area self is not intersected: it takes over other's coordinates as they are.
+ if self.get_area() == 0:
+ self.x_min = other.x_min
+ self.y_min = other.y_min
+ self.x_max = other.x_max
+ self.y_max = other.y_max
+ else:
+ # Overlap region: the larger of the minima and the smaller of the maxima.
+ self.x_min = max(self.x_min, other.x_min)
+ self.y_min = max(self.y_min, other.y_min)
+ self.x_max = min(self.x_max, other.x_max)
+ self.y_max = min(self.y_max, other.y_max)
+
+ # An inverted or zero-area result is collapsed to the all-zero rectangle.
+ if self.x_min > self.x_max or self.y_min > self.y_max or self.get_area() == 0:
+ self.x_min = 0
+ self.y_min = 0
+ self.x_max = 0
+ self.y_max = 0
+
+ return self
+
+ def include_rect(self, bbox):
+ """Grows this rectangle in place so that it also covers bbox (a coordinate sequence,
+ not a Rect) and returns self"""
+ other = Rect(bbox)
+
+ # A zero-area self contributes nothing: the result is just bbox.
+ if self.get_area() == 0:
+ self.x_min = other.x_min
+ self.y_min = other.y_min
+ self.x_max = other.x_max
+ self.y_max = other.y_max
+ return self
+
+ self.x_min = min(self.x_min, other.x_min)
+ self.y_min = min(self.y_min, other.y_min)
+ self.x_max = max(self.x_max, other.x_max)
+ self.y_max = max(self.y_max, other.y_max)
+
+ # if self.get_area() == 0:
+ # self.x_min = other.x_min
+ # self.y_min = other.y_min
+ # self.x_max = other.x_max
+ # self.y_max = other.y_max
+
+ return self
+
+ def get_bbox(self):
+ """Returns the coordinates that define the rectangle as [x_min, y_min, x_max, y_max]"""
+ return [self.x_min, self.y_min, self.x_max, self.y_max]
+
+
+def apply_threshold(objects, threshold):
+    """
+    Return a new list holding only the objects whose "score" is at least `threshold`.
+    """
+    kept = []
+    for candidate in objects:
+        if candidate["score"] >= threshold:
+            kept.append(candidate)
+    return kept
+
+
+# def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
+# """
+# Filter out bounding boxes whose confidence is below the confidence threshold for
+# its associated class label.
+# """
+# # Apply class-specific thresholds
+# indices_above_threshold = [
+# idx
+# for idx, (score, label) in enumerate(zip(scores, labels))
+# if score >= class_thresholds[class_names[label]]
+# ]
+# bboxes = [bboxes[idx] for idx in indices_above_threshold]
+# scores = [scores[idx] for idx in indices_above_threshold]
+# labels = [labels[idx] for idx in indices_above_threshold]
+
+# return bboxes, scores, labels
+
+
+def refine_rows(rows, tokens, score_threshold):
+    """
+    De-duplicate the detected rows and return them ordered top to bottom.
+
+    When tokens are available, rows are suppressed by shared token containment and rows
+    without any text are dropped; otherwise plain overlap-based NMS is used.
+    `score_threshold` is accepted but not used by this function.
+    """
+    if len(tokens) == 0:
+        rows = nms(rows, match_criteria="object2_overlap", match_threshold=0.5, keep_higher=True)
+    else:
+        rows = nms_by_containment(rows, tokens, overlap_threshold=0.5)
+        remove_objects_without_content(tokens, rows)
+
+    return sort_objects_top_to_bottom(rows) if len(rows) > 1 else rows
+
+
+def refine_columns(columns, tokens, score_threshold):
+    """
+    De-duplicate the detected columns and return them ordered left to right.
+
+    When tokens are available, columns are suppressed by shared token containment and
+    columns without any text are dropped; otherwise plain overlap-based NMS is used.
+    `score_threshold` is accepted but not used by this function.
+    """
+    if len(tokens) == 0:
+        columns = nms(
+            columns, match_criteria="object2_overlap", match_threshold=0.25, keep_higher=True
+        )
+    else:
+        columns = nms_by_containment(columns, tokens, overlap_threshold=0.5)
+        remove_objects_without_content(tokens, columns)
+
+    return sort_objects_left_to_right(columns) if len(columns) > 1 else columns
+
+
+def nms_by_containment(container_objects, package_objects, overlap_threshold=0.5):
+    """
+    Non-maxima suppression of containers based on the packages slotted into them.
+
+    Containers are visited from highest to lowest score. A container is suppressed when it
+    holds no packages, or when it shares a package with a higher-scoring container that has
+    not itself been suppressed. The highest-scoring container is always kept.
+    Returns the surviving containers in score order.
+    """
+    ranked = sort_objects_by_score(container_objects)
+
+    packages_by_container, _, _ = slot_into_containers(
+        ranked,
+        package_objects,
+        overlap_threshold=overlap_threshold,
+        forced_assignment=False,
+    )
+    package_sets = [set(package_nums) for package_nums in packages_by_container]
+
+    suppressed = [False] * len(ranked)
+    for current in range(1, len(ranked)):
+        current_packages = package_sets[current]
+        if not current_packages:
+            suppressed[current] = True
+        for earlier in range(current):
+            if suppressed[earlier]:
+                continue
+            if current_packages & package_sets[earlier]:
+                suppressed[current] = True
+
+    return [obj for obj, dropped in zip(ranked, suppressed) if not dropped]
+
+
+def slot_into_containers(
+    container_objects,
+    package_objects,
+    overlap_threshold=0.5,
+    forced_assignment=False,
+):
+    """
+    Slot a collection of objects into the container they occupy most (the container which holds the
+    largest fraction of the object).
+
+    Returns a tuple (container_assignments, package_assignments, best_match_scores):
+    - container_assignments[i] lists the package indices assigned to container i
+    - package_assignments[j] lists the container index chosen for package j (empty if none)
+    - best_match_scores[j] is the best overlap fraction found for package j
+
+    A package is assigned when its best overlap fraction reaches `overlap_threshold`, or
+    always when `forced_assignment` is True. Packages with zero area score 0.0 against
+    every container instead of raising ZeroDivisionError.
+    """
+    best_match_scores = []
+
+    container_assignments = [[] for _ in container_objects]
+    package_assignments = [[] for _ in package_objects]
+
+    if len(container_objects) == 0 or len(package_objects) == 0:
+        return container_assignments, package_assignments, best_match_scores
+
+    for package_num, package in enumerate(package_objects):
+        match_scores = []
+        package_area = Rect(package["bbox"]).get_area()
+        for container_num, container in enumerate(container_objects):
+            if package_area > 0:
+                # Rect.intersect mutates its receiver, so use a fresh Rect per container.
+                container_rect = Rect(container["bbox"])
+                intersect_area = container_rect.intersect(Rect(package["bbox"])).get_area()
+                overlap_fraction = intersect_area / package_area
+            else:
+                # Degenerate package: no fraction of it can lie inside any container.
+                overlap_fraction = 0.0
+
+            match_scores.append(
+                {"container": container, "container_num": container_num, "score": overlap_fraction}
+            )
+
+        # The sort is stable, so ties go to the container that comes first.
+        best_match_score = sort_objects_by_score(match_scores)[0]
+        best_match_scores.append(best_match_score["score"])
+        if forced_assignment or best_match_score["score"] >= overlap_threshold:
+            container_assignments[best_match_score["container_num"]].append(package_num)
+            package_assignments[package_num].append(best_match_score["container_num"])
+
+    return container_assignments, package_assignments, best_match_scores
+
+
+def sort_objects_by_score(objects, reverse=True):
+    """
+    Return a new list of the objects ordered by their "score", highest first by default.
+    The input is left untouched and objects with equal scores keep their relative order.
+    """
+
+    def score_of(obj):
+        return obj["score"]
+
+    ordered = list(objects)
+    ordered.sort(key=score_of, reverse=reverse)
+    return ordered
+
+
+def remove_objects_without_content(page_spans, objects):
+    """
+    Delete, in place, every object (row, column, supercell, ...) whose bounding box holds
+    no non-whitespace text from `page_spans`. Nothing is returned.
+    """
+    # Walk a snapshot so that removing from `objects` does not disturb the iteration.
+    for candidate in list(objects):
+        text, _ = extract_text_inside_bbox(page_spans, candidate["bbox"])
+        if not text.strip():
+            objects.remove(candidate)
+
+
+def extract_text_inside_bbox(spans, bbox):
+    """
+    Return (text, spans) for the spans that fall inside `bbox`: the joined text with
+    integer superscripts removed, and the list of spans it was built from.
+    """
+    contained_spans = get_bbox_span_subset(spans, bbox)
+    text = extract_text_from_spans(contained_spans, remove_integer_superscripts=True)
+    return text, contained_spans
+
+
+def get_bbox_span_subset(spans, bbox, threshold=0.5):
+    """
+    Return the spans that lie inside a bounding box, in their original order.
+
+    threshold: the fraction of a span's area that must overlap with the bbox.
+    """
+    return [span for span in spans if overlaps(span["bbox"], bbox, threshold)]
+
+
+def overlaps(bbox1, bbox2, threshold=0.5):
+    """
+    Return True when at least `threshold` of bbox1's area lies inside bbox2.
+    A bbox1 with zero area never overlaps anything.
+    """
+    first = Rect(list(bbox1))
+    first_area = first.get_area()
+    if first_area == 0:
+        return False
+    shared_area = first.intersect(Rect(list(bbox2))).get_area()
+    return shared_area / first_area >= threshold
+
+
+def extract_text_from_spans(spans, join_with_space=True, remove_integer_superscripts=True):
+ """
+ Convert a collection of page tokens/words/spans into a single text string.
+
+ Each span is a dict with "text", "span_num", "line_num" and "block_num" keys and an
+ optional integer "flags" key. Spans are ordered by block, then line, then span number
+ before being joined. Returns "" when no spans remain.
+
+ join_with_space: join pieces with a single space when True, with nothing otherwise.
+ remove_integer_superscripts: inspect spans whose "flags" has bit 0 set; ones whose
+ text is all digits are left out of the result.
+
+ NOTE: the span dicts in `spans` may be modified in place (a "superscript" key can be
+ added); the `spans` list itself is not reordered or shortened.
+ """
+
+ join_char = " " if join_with_space else ""
+ # Work on a shallow copy so removals and sorting do not touch the caller's list.
+ spans_copy = spans[:]
+
+ if remove_integer_superscripts:
+ for span in spans:
+ if "flags" not in span:
+ continue
+ flags = span["flags"]
+ if flags & 2**0: # superscript flag
+ if span["text"].strip().isdigit():
+ spans_copy.remove(span)
+ else:
+ span["superscript"] = True
+
+ if len(spans_copy) == 0:
+ return ""
+
+ # Three stable sorts, least significant key first: the final order is by block_num,
+ # then line_num, then span_num.
+ spans_copy.sort(key=lambda span: span["span_num"])
+ spans_copy.sort(key=lambda span: span["line_num"])
+ spans_copy.sort(key=lambda span: span["block_num"])
+
+ # Force the span at the end of every line within a block to have exactly one space
+ # unless the line ends with a space or ends with a non-space followed by a hyphen
+ line_texts = []
+ line_span_texts = [spans_copy[0]["text"]]
+ # Walk consecutive pairs; a change of block or line number closes the current line.
+ for span1, span2 in zip(spans_copy[:-1], spans_copy[1:]):
+ if (
+ not span1["block_num"] == span2["block_num"]
+ or not span1["line_num"] == span2["line_num"]
+ ):
+ line_text = join_char.join(line_span_texts).strip()
+ if (
+ len(line_text) > 0
+ and not line_text[-1] == " "
+ and not (len(line_text) > 1 and line_text[-1] == "-" and not line_text[-2] == " ")
+ ):
+ if not join_with_space:
+ line_text += " "
+ line_texts.append(line_text)
+ line_span_texts = [span2["text"]]
+ else:
+ line_span_texts.append(span2["text"])
+ # Flush the last line, which the pairwise loop never closes.
+ line_text = join_char.join(line_span_texts)
+ line_texts.append(line_text)
+
+ return join_char.join(line_texts).strip()
+
+
+def sort_objects_left_to_right(objs):
+    """
+    Return a new list of the objects ordered left to right, using the sum of each bbox's
+    left and right x coordinates as the sort key.
+    """
+
+    def horizontal_key(obj):
+        box = obj["bbox"]
+        return box[0] + box[2]
+
+    return sorted(objs, key=horizontal_key)
+
+
+def sort_objects_top_to_bottom(objs):
+    """
+    Return a new list of the objects ordered top to bottom, using the sum of each bbox's
+    top and bottom y coordinates as the sort key.
+    """
+
+    def vertical_key(obj):
+        box = obj["bbox"]
+        return box[1] + box[3]
+
+    return sorted(objs, key=vertical_key)
+
+
+def align_columns(columns, bbox):
+    """
+    Snap the top and bottom edge of every column to the top and bottom of `bbox`.
+
+    The column dicts are modified in place and the same list is returned. Failures are
+    reported with print() and otherwise ignored, which can leave the columns partly
+    aligned.
+    """
+    try:
+        for col in columns:
+            col["bbox"][1] = bbox[1]
+            col["bbox"][3] = bbox[3]
+    except Exception as err:
+        print("Could not align columns: {}".format(err))
+
+    return columns
+
+
+def align_rows(rows, bbox):
+    """
+    Snap the left and right edge of every row to the left and right of `bbox`.
+
+    The row dicts are modified in place and the same list is returned. Failures are
+    reported with print() and otherwise ignored, which can leave the rows partly aligned.
+    """
+    try:
+        for current_row in rows:
+            current_row["bbox"][0] = bbox[0]
+            current_row["bbox"][2] = bbox[2]
+    except Exception as err:
+        print("Could not align rows: {}".format(err))
+
+    return rows
+
+
+def nms(objects, match_criteria="object2_overlap", match_threshold=0.05, keep_higher=True):
+    """
+    A customizable version of non-maxima suppression (NMS).
+
+    Default behavior: If a lower-confidence object overlaps at least 5% of its area
+    with a higher-confidence object, remove the lower-confidence object.
+
+    objects: list of dicts; each object dict must have a 'bbox' and a 'score' field
+    match_criteria: how to measure how much two objects "overlap"; one of
+        "object1_overlap", "object2_overlap" or "iou"
+    match_threshold: the cutoff for determining that overlap requires suppression of one object
+    keep_higher: if True, keep the object with the higher metric; otherwise, keep the lower
+
+    Returns the surviving objects in score order ([] for empty input).
+    Raises ValueError when two objects have to be compared and `match_criteria` is not one
+    of the supported values.
+    """
+    if len(objects) == 0:
+        return []
+
+    objects = sort_objects_by_score(objects, reverse=keep_higher)
+
+    num_objects = len(objects)
+    suppression = [False for obj in objects]
+
+    for object2_num in range(1, num_objects):
+        object2_rect = Rect(objects[object2_num]["bbox"])
+        object2_area = object2_rect.get_area()
+        for object1_num in range(object2_num):
+            # Only objects that survived so far may suppress a later one.
+            if suppression[object1_num]:
+                continue
+            # Rect.intersect mutates its receiver, hence a fresh Rect for every pair.
+            object1_rect = Rect(objects[object1_num]["bbox"])
+            object1_area = object1_rect.get_area()
+            intersect_area = object1_rect.intersect(object2_rect).get_area()
+            try:
+                if match_criteria == "object1_overlap":
+                    metric = intersect_area / object1_area
+                elif match_criteria == "object2_overlap":
+                    metric = intersect_area / object2_area
+                elif match_criteria == "iou":
+                    metric = intersect_area / (object1_area + object2_area - intersect_area)
+                else:
+                    raise ValueError(
+                        "Unknown match_criteria {!r}; expected 'object1_overlap', "
+                        "'object2_overlap' or 'iou'".format(match_criteria)
+                    )
+            except ZeroDivisionError:
+                # Degenerate (zero-area) boxes cannot be compared; skip this pair.
+                continue
+            if metric >= match_threshold:
+                suppression[object2_num] = True
+                break
+
+    return [obj for idx, obj in enumerate(objects) if not suppression[idx]]
+
+
+def align_supercells(supercells, rows, columns):
+ """
+ For each supercell, align it to the rows it intersects 50% of the height of,
+ and the columns it intersects 50% of the width of.
+ Eliminate supercells for which there are no rows and columns it intersects 50% with.
+
+ supercells, rows and columns are lists of dicts with a "bbox" key. The supercell dicts
+ are modified in place ("header", "bbox", "row_numbers" and "column_numbers" are written).
+ Returns a new list with the supercells that span more than one row or column, plus
+ "propagated" supercells generated above "span" supercells found in the header.
+
+ NOTE(review): the overlap fractions divide by row height / column width, so a
+ zero-height row or zero-width column would raise ZeroDivisionError - confirm that
+ callers never pass such boxes.
+ """
+ aligned_supercells = []
+
+ for supercell in supercells:
+ supercell["header"] = False
+ row_bbox_rect = None
+ col_bbox_rect = None
+ intersecting_header_rows = set()
+ intersecting_data_rows = set()
+ for row_num, row in enumerate(rows):
+ row_height = row["bbox"][3] - row["bbox"][1]
+ supercell_height = supercell["bbox"][3] - supercell["bbox"][1]
+ # Vertical overlap between the row and the supercell (negative when disjoint).
+ min_row_overlap = max(row["bbox"][1], supercell["bbox"][1])
+ max_row_overlap = min(row["bbox"][3], supercell["bbox"][3])
+ overlap_height = max_row_overlap - min_row_overlap
+ if "span" in supercell:
+ overlap_fraction = max(
+ overlap_height / row_height, overlap_height / supercell_height
+ )
+ else:
+ overlap_fraction = overlap_height / row_height
+ if overlap_fraction >= 0.5:
+ # NOTE(review): this reads a "header" key, while the rows built in tables.py
+ # carry a "column header" key - confirm which key is intended.
+ if "header" in row and row["header"]:
+ intersecting_header_rows.add(row_num)
+ else:
+ intersecting_data_rows.add(row_num)
+
+ # Supercell cannot span across the header boundary; eliminate whichever
+ # group of rows is the smallest
+ supercell["header"] = False
+ if len(intersecting_data_rows) > 0 and len(intersecting_header_rows) > 0:
+ if len(intersecting_data_rows) > len(intersecting_header_rows):
+ intersecting_header_rows = set()
+ else:
+ intersecting_data_rows = set()
+ if len(intersecting_header_rows) > 0:
+ supercell["header"] = True
+ elif "span" in supercell:
+ continue # Require span supercell to be in the header
+ intersecting_rows = intersecting_data_rows.union(intersecting_header_rows)
+ # Determine vertical span of aligned supercell
+ for row_num in intersecting_rows:
+ if row_bbox_rect is None:
+ row_bbox_rect = Rect(rows[row_num]["bbox"])
+ else:
+ row_bbox_rect = row_bbox_rect.include_rect(rows[row_num]["bbox"])
+ if row_bbox_rect is None:
+ continue
+
+ # Same procedure horizontally: collect the columns the supercell covers.
+ intersecting_cols = []
+ for col_num, col in enumerate(columns):
+ col_width = col["bbox"][2] - col["bbox"][0]
+ supercell_width = supercell["bbox"][2] - supercell["bbox"][0]
+ min_col_overlap = max(col["bbox"][0], supercell["bbox"][0])
+ max_col_overlap = min(col["bbox"][2], supercell["bbox"][2])
+ overlap_width = max_col_overlap - min_col_overlap
+ if "span" in supercell:
+ overlap_fraction = max(overlap_width / col_width, overlap_width / supercell_width)
+ # Multiply by 2 effectively lowers the threshold to 0.25
+ if supercell["header"]:
+ overlap_fraction = overlap_fraction * 2
+ else:
+ overlap_fraction = overlap_width / col_width
+ if overlap_fraction >= 0.5:
+ intersecting_cols.append(col_num)
+ if col_bbox_rect is None:
+ col_bbox_rect = Rect(col["bbox"])
+ else:
+ col_bbox_rect = col_bbox_rect.include_rect(col["bbox"])
+ if col_bbox_rect is None:
+ continue
+
+ # The aligned box is the overlap of the row hull and the column hull.
+ supercell_bbox = row_bbox_rect.intersect(col_bbox_rect).get_bbox()
+ supercell["bbox"] = supercell_bbox
+
+ # Only a true supercell if it joins across multiple rows or columns
+ if (
+ len(intersecting_rows) > 0
+ and len(intersecting_cols) > 0
+ and (len(intersecting_rows) > 1 or len(intersecting_cols) > 1)
+ ):
+ supercell["row_numbers"] = list(intersecting_rows)
+ supercell["column_numbers"] = intersecting_cols
+ aligned_supercells.append(supercell)
+
+ # A span supercell in the header means there must be supercells above it in the header
+ if "span" in supercell and supercell["header"] and len(supercell["column_numbers"]) > 1:
+ for row_num in range(0, min(supercell["row_numbers"])):
+ new_supercell = {
+ "row_numbers": [row_num],
+ "column_numbers": supercell["column_numbers"],
+ "score": supercell["score"],
+ "propagated": True,
+ }
+ new_supercell_columns = [columns[idx] for idx in supercell["column_numbers"]]
+ new_supercell_rows = [rows[idx] for idx in supercell["row_numbers"]]
+ bbox = [
+ min([column["bbox"][0] for column in new_supercell_columns]),
+ min([row["bbox"][1] for row in new_supercell_rows]),
+ max([column["bbox"][2] for column in new_supercell_columns]),
+ max([row["bbox"][3] for row in new_supercell_rows]),
+ ]
+ new_supercell["bbox"] = bbox
+ aligned_supercells.append(new_supercell)
+
+ return aligned_supercells
+
+
+def nms_supercells(supercells):
+    """
+    Resolve overlaps between supercells, preferring to shrink rather than to drop.
+
+    Supercells are visited from highest to lowest score. Each one is shrunk (in place)
+    until it no longer shares a grid cell with any higher-scoring supercell, including
+    ones that were themselves dropped. It is dropped when it ends up empty or no longer
+    spans more than one row or column. The highest-scoring supercell is always kept.
+    Returns the surviving supercells in score order.
+    """
+    ranked = sort_objects_by_score(supercells)
+
+    survivors = []
+    for position, candidate in enumerate(ranked):
+        if position > 0:
+            for stronger in ranked[:position]:
+                remove_supercell_overlap(stronger, candidate)
+            num_rows = len(candidate["row_numbers"])
+            num_columns = len(candidate["column_numbers"])
+            if (num_rows < 2 and num_columns < 2) or num_rows == 0 or num_columns == 0:
+                continue
+        survivors.append(candidate)
+
+    return survivors
+
+
+def header_supercell_tree(supercells):
+ """
+ Make sure no supercell in the header is below more than one supercell in any row above it.
+ The cells in the header form a tree, but a supercell with more than one supercell in a row
+ above it means that some cell has more than one parent, which is not allowed. Eliminate
+ any supercell that would cause this to be violated.
+
+ `supercells` is modified in place (offending header supercells are removed); nothing is
+ returned. Every header supercell must have non-empty "row_numbers".
+ """
+ header_supercells = [
+ supercell for supercell in supercells if "header" in supercell and supercell["header"]
+ ]
+ header_supercells = sort_objects_by_score(header_supercells)
+
+ # Iterate over a copy; removals below are applied to `supercells`, while the candidate
+ # ancestors are still taken from the full header_supercells list.
+ for header_supercell in header_supercells[:]:
+ # Number of ancestor supercells found in each row above this supercell.
+ ancestors_by_row = defaultdict(int)
+ min_row = min(header_supercell["row_numbers"])
+ for header_supercell2 in header_supercells:
+ max_row2 = max(header_supercell2["row_numbers"])
+ # An ancestor lies entirely above and covers all of this supercell's columns.
+ if max_row2 < min_row:
+ if set(header_supercell["column_numbers"]).issubset(
+ set(header_supercell2["column_numbers"])
+ ):
+ for row2 in header_supercell2["row_numbers"]:
+ ancestors_by_row[row2] += 1
+ # Every row above must hold exactly one ancestor, otherwise the supercell is dropped.
+ for row in range(0, min_row):
+ if not ancestors_by_row[row] == 1:
+ supercells.remove(header_supercell)
+ break
+
+
+def remove_supercell_overlap(supercell1, supercell2):
+ """
+ This function resolves overlap between supercells (supercells must be
+ disjoint) by iteratively shrinking supercells by the fewest grid cells
+ necessary to resolve the overlap.
+ Example:
+ If two supercells overlap at grid cell (R, C), and supercell #1 is less
+ confident than supercell #2, we eliminate either row R from supercell #1
+ or column C from supercell #1 by comparing the number of columns in row R
+ versus the number of rows in column C. If the number of columns in row R
+ is less than the number of rows in column C, we eliminate row R from
+ supercell #1. This resolves the overlap by removing fewer grid cells from
+ supercell #1 than if we eliminated column C from it.
+
+ NOTE: in this implementation the argument that gets shrunk is `supercell2`; its
+ "row_numbers" / "column_numbers" lists are modified in place, while `supercell1` is
+ only read. Nothing is returned.
+ """
+ common_rows = set(supercell1["row_numbers"]).intersection(set(supercell2["row_numbers"]))
+ common_columns = set(supercell1["column_numbers"]).intersection(
+ set(supercell2["column_numbers"])
+ )
+
+ # While the supercells have overlapping grid cells, continue shrinking the less-confident
+ # supercell one row or one column at a time
+ while len(common_rows) > 0 and len(common_columns) > 0:
+ # Try to shrink the supercell as little as possible to remove the overlap;
+ # if the supercell has fewer rows than columns, remove an overlapping column,
+ # because this removes fewer grid cells from the supercell;
+ # otherwise remove an overlapping row
+ if len(supercell2["row_numbers"]) < len(supercell2["column_numbers"]):
+ min_column = min(supercell2["column_numbers"])
+ max_column = max(supercell2["column_numbers"])
+ # Only an outermost column is trimmed; when the overlap touches neither
+ # edge, all of the supercell's columns are cleared.
+ if max_column in common_columns:
+ common_columns.remove(max_column)
+ supercell2["column_numbers"].remove(max_column)
+ elif min_column in common_columns:
+ common_columns.remove(min_column)
+ supercell2["column_numbers"].remove(min_column)
+ else:
+ supercell2["column_numbers"] = []
+ common_columns = set()
+ else:
+ min_row = min(supercell2["row_numbers"])
+ max_row = max(supercell2["row_numbers"])
+ # Same edge-first trimming, applied to rows.
+ if max_row in common_rows:
+ common_rows.remove(max_row)
+ supercell2["row_numbers"].remove(max_row)
+ elif min_row in common_rows:
+ common_rows.remove(min_row)
+ supercell2["row_numbers"].remove(min_row)
+ else:
+ supercell2["row_numbers"] = []
+ common_rows = set()
diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py
new file mode 100644
index 00000000..c90c359d
--- /dev/null
+++ b/unstructured_inference/models/tables.py
@@ -0,0 +1,623 @@
+# https://github.com/microsoft/table-transformer/blob/main/src/inference.py
+# https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Table%20Transformer/Using_Table_Transformer_for_table_detection_and_table_structure_recognition.ipynb
+import torch
+import logging
+
+from unstructured_inference.models.unstructuredmodel import UnstructuredModel
+from unstructured_inference.logger import logger
+
+from collections import defaultdict
+import xml.etree.ElementTree as ET
+
+import cv2
+import numpy as np
+import pandas as pd
+
+import pytesseract
+
+from transformers import TableTransformerForObjectDetection
+from transformers import DetrImageProcessor
+from PIL import Image
+from typing import Union, Optional
+from pathlib import Path
+import platform
+
+from . import table_postprocess as postprocess
+from unstructured_inference.models.table_postprocess import Rect
+
+
+class UnstructuredTableTransformerModel(UnstructuredModel):
+    """Unstructured model wrapper for table-transformer."""
+
+    def __init__(self):
+        pass
+
+    def predict(self, x: Image.Image):
+        """Predict table structure deferring to run_prediction"""
+        super().predict(x)
+        return self.run_prediction(x)
+
+    def initialize(
+        self,
+        model: Union[str, Path, TableTransformerForObjectDetection] = None,
+        device: Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
+    ):
+        """Loads the table structure model using the specified parameters.
+
+        model: value handed to TableTransformerForObjectDetection.from_pretrained.
+        device: device the model and the encoded inputs are moved to. The default is
+            evaluated once, when the method is defined ("cuda" if available, else "cpu").
+
+        Raises ImportError, chained to the underlying EnvironmentError, when the model
+        cannot be loaded.
+        """
+        self.device = device
+        self.feature_extractor = DetrImageProcessor()
+
+        try:
+            logging.info("Loading the table structure model ...")
+            self.model = TableTransformerForObjectDetection.from_pretrained(model)
+            self.model.eval()
+
+        except EnvironmentError as err:
+            logging.critical("Failed to initialize the model.")
+            logging.critical("Ensure that the model is correct")
+            # Chain the original error so the real cause stays visible in the traceback.
+            raise ImportError(
+                "Review the parameters to initialize a UnstructuredTableTransformerModel obj"
+            ) from err
+        self.model.to(device)
+
+    def run_prediction(self, x: Image.Image):
+        """Predict table structure.
+
+        Runs the structure model on the image, OCRs the image into tokens (PaddleOCR on
+        x86_64 machines, pytesseract elsewhere) and returns the HTML of the first
+        recognized table, or "" when no table was recognized.
+        """
+        with torch.no_grad():
+            encoding = self.feature_extractor(x, return_tensors="pt").to(self.device)
+            outputs_structure = self.model(**encoding)
+
+        if platform.machine() == "x86_64":
+            from unstructured_inference.models import paddle_ocr
+
+            paddle_result = paddle_ocr.load_agent().ocr(np.array(x), cls=True)
+
+            tokens = []
+            for res in paddle_result:
+                # `res or []` tolerates a per-image result of None (presumably what some
+                # PaddleOCR versions return when no text is found - TODO confirm).
+                for line in res or []:
+                    # line[0] holds the corner points of the text box, line[1][0] its text.
+                    xs = [point[0] for point in line[0]]
+                    ys = [point[1] for point in line[0]]
+                    tokens.append(
+                        {"bbox": [min(xs), min(ys), max(xs), max(ys)], "text": line[1][0]}
+                    )
+        else:
+            # Upscale before OCR, then divide the resulting coordinates by the same
+            # factor so that token boxes are expressed in the original image's pixels.
+            zoom = 6
+            img = cv2.resize(
+                cv2.cvtColor(np.array(x), cv2.COLOR_RGB2BGR),
+                None,
+                fx=zoom,
+                fy=zoom,
+                interpolation=cv2.INTER_CUBIC,
+            )
+
+            kernel = np.ones((1, 1), np.uint8)
+            img = cv2.dilate(img, kernel, iterations=1)
+            img = cv2.erode(img, kernel, iterations=1)
+
+            ocr_df: pd.DataFrame = pytesseract.image_to_data(
+                Image.fromarray(img), output_type="data.frame"
+            )
+
+            ocr_df = ocr_df.dropna()
+
+            tokens = [
+                {
+                    "bbox": [
+                        idtx.left / zoom,
+                        idtx.top / zoom,
+                        (idtx.left + idtx.width) / zoom,
+                        (idtx.top + idtx.height) / zoom,
+                    ],
+                    "text": idtx.text,
+                }
+                for idtx in ocr_df.itertuples()
+            ]
+
+        # Tokens are kept in the order produced by the OCR engine. A former
+        # `sorted(tokens, key=y * 10000 + x)` call here discarded its result and so never
+        # reordered anything; it was dropped rather than made effective, because sorting
+        # word-level boxes by their top edge would reorder words within a text line.
+
+        # 'tokens' is a list of tokens
+        # Need to be in a relative reading order
+        # If no order is provided, use current order
+        for idx, token in enumerate(tokens):
+            if "span_num" not in token:
+                token["span_num"] = idx
+            if "line_num" not in token:
+                token["line_num"] = 0
+            if "block_num" not in token:
+                token["block_num"] = 0
+
+        html = recognize(outputs_structure, x, tokens=tokens, out_html=True)["html"]
+        prediction = html[0] if html else ""
+        return prediction
+
+
+tables_agent: UnstructuredTableTransformerModel = UnstructuredTableTransformerModel()
+
+
+def load_agent():
+    """Loads the table structure model into the module-level tables_agent.
+
+    The model is only initialized when tables_agent has no `model` attribute yet, so
+    repeated calls do not load it again. Returns None.
+    """
+    global tables_agent
+
+    if not hasattr(tables_agent, "model"):
+        logger.info("Loading the table agent ...")
+        tables_agent.initialize("microsoft/table-transformer-structure-recognition")
+
+    return
+
+
+def get_class_map(data_type: str):
+    """Defines class map dictionaries
+
+    data_type: "structure" for the table structure recognition classes, or "detection"
+        for the table detection classes.
+
+    Returns a dict mapping class name to class index.
+    Raises ValueError for any other data_type.
+    """
+    if data_type == "structure":
+        return {
+            "table": 0,
+            "table column": 1,
+            "table row": 2,
+            "table column header": 3,
+            "table projected row header": 4,
+            "table spanning cell": 5,
+            "no object": 6,
+        }
+    if data_type == "detection":
+        return {"table": 0, "table rotated": 1, "no object": 2}
+    raise ValueError(
+        f"Unknown data_type {data_type!r}; expected 'structure' or 'detection'."
+    )
+
+
+structure_class_thresholds = {
+ "table": 0.5,
+ "table column": 0.5,
+ "table row": 0.5,
+ "table column header": 0.5,
+ "table projected row header": 0.5,
+ "table spanning cell": 0.5,
+ "no object": 10,
+}
+
+
+def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False):
+    """Turn raw table structure model outputs into per-table representations.
+
+    Returns a dict of output formats. It has an "html" entry (one HTML rendering per
+    recognized table) only when `out_html` is True; otherwise it is empty.
+    """
+    name_to_idx = get_class_map("structure")
+    idx_to_name = {idx: name for name, idx in name_to_idx.items()}
+
+    # Detected objects with class labels, then a consistent structure per table.
+    objects = outputs_to_objects(outputs, img.size, idx_to_name)
+    tables_structure = objects_to_structures(objects, tokens, structure_class_thresholds)
+
+    # All cells of each table: grid cells and spanning cells.
+    tables_cells = []
+    for structure in tables_structure:
+        tables_cells.append(structure_to_cells(structure, tokens)[0])
+
+    out_formats = {}
+    if out_html:
+        out_formats["html"] = [cells_to_html(cells) for cells in tables_cells]
+    return out_formats
+
+
+def outputs_to_objects(outputs, img_size, class_idx2name):
+    """Convert the model outputs of the first batch item into labelled objects.
+
+    Each prediction becomes a dict with "label", "score" and "bbox" (rescaled to
+    `img_size`); predictions whose best class is "no object" are left out.
+    """
+    best = outputs["logits"].softmax(-1).max(-1)
+    labels = best.indices.detach().cpu().numpy()[0]
+    scores = best.values.detach().cpu().numpy()[0]
+    boxes = rescale_bboxes(outputs["pred_boxes"].detach().cpu()[0], img_size)
+
+    objects = []
+    for label, score, box in zip(labels, scores, boxes):
+        class_label = class_idx2name[int(label)]
+        if class_label == "no object":
+            continue
+        objects.append(
+            {
+                "label": class_label,
+                "score": float(score),
+                "bbox": [float(coordinate) for coordinate in box.tolist()],
+            }
+        )
+
+    return objects
+
+
+# for output bounding box post-processing
+def box_cxcywh_to_xyxy(x):
+    """Convert boxes from (center-x, center-y, width, height) to
+    (x-min, y-min, x-max, y-max); the four values are read from the last dimension."""
+    center_x, center_y, width, height = x.unbind(-1)
+    half_width = 0.5 * width
+    half_height = 0.5 * height
+    corners = [
+        center_x - half_width,
+        center_y - half_height,
+        center_x + half_width,
+        center_y + half_height,
+    ]
+    return torch.stack(corners, dim=1)
+
+
+def rescale_bboxes(out_bbox, size):
+    """Convert relative (cx, cy, w, h) boxes to corner format and scale them to an image
+    of the given (width, height) size."""
+    width, height = size
+    corners = box_cxcywh_to_xyxy(out_bbox)
+    scale = torch.tensor([width, height, width, height], dtype=torch.float32)
+    return corners * scale
+
+
+def iob(bbox1, bbox2):
+    """
+    Return the fraction of bbox1's area that is covered by bbox2 (intersection over
+    box area), or 0 when bbox1 has no area.
+    """
+    overlap = Rect(bbox1).intersect(Rect(bbox2))
+    box_area = Rect(bbox1).get_area()
+    if not box_area > 0:
+        return 0
+    return overlap.get_area() / box_area
+
+
+def objects_to_structures(objects, tokens, class_thresholds):
+ """
+ Process the bounding boxes produced by the table structure recognition model into
+ a *consistent* set of table structures (rows, columns, spanning cells, headers).
+ This entails resolving conflicts/overlaps, and ensuring the boxes meet certain alignment
+ conditions (for example: rows should all have the same width, etc.).
+
+ objects: detected objects, dicts with "label", "score" and "bbox".
+ tokens: text tokens, dicts with a "bbox".
+ class_thresholds: per-class score thresholds keyed by class label.
+
+ Returns one structure dict per object labelled "table", with the keys "rows",
+ "columns", "column headers" and "spanning cells".
+
+ NOTE: the dicts in `objects` are modified in place: flags are added to rows and
+ spanning cells, row/column boxes are aligned, and each table's "bbox" is replaced.
+ """
+
+ tables = [obj for obj in objects if obj["label"] == "table"]
+ table_structures = []
+
+ for table in tables:
+ # Keep the objects and tokens that lie at least half inside this table.
+ table_objects = [obj for obj in objects if iob(obj["bbox"], table["bbox"]) >= 0.5]
+ table_tokens = [token for token in tokens if iob(token["bbox"], table["bbox"]) >= 0.5]
+
+ structure = {}
+
+ columns = [obj for obj in table_objects if obj["label"] == "table column"]
+ rows = [obj for obj in table_objects if obj["label"] == "table row"]
+ column_headers = [obj for obj in table_objects if obj["label"] == "table column header"]
+ spanning_cells = [obj for obj in table_objects if obj["label"] == "table spanning cell"]
+ for obj in spanning_cells:
+ obj["projected row header"] = False
+ projected_row_headers = [
+ obj for obj in table_objects if obj["label"] == "table projected row header"
+ ]
+ for obj in projected_row_headers:
+ obj["projected row header"] = True
+ # Projected row headers are handled together with the spanning cells from here on.
+ spanning_cells += projected_row_headers
+ # Flag rows that lie at least half inside a detected column header.
+ for obj in rows:
+ obj["column header"] = False
+ for header_obj in column_headers:
+ if iob(obj["bbox"], header_obj["bbox"]) >= 0.5:
+ obj["column header"] = True
+
+ # Refine table structures
+ rows = postprocess.refine_rows(rows, table_tokens, class_thresholds["table row"])
+ columns = postprocess.refine_columns(
+ columns, table_tokens, class_thresholds["table column"]
+ )
+
+ # Shrink table bbox to just the total height of the rows
+ # and the total width of the columns
+ row_rect = Rect()
+ for obj in rows:
+ row_rect.include_rect(obj["bbox"])
+ column_rect = Rect()
+ for obj in columns:
+ column_rect.include_rect(obj["bbox"])
+ table["row_column_bbox"] = [
+ column_rect.x_min,
+ row_rect.y_min,
+ column_rect.x_max,
+ row_rect.y_max,
+ ]
+ table["bbox"] = table["row_column_bbox"]
+
+ # Process the rows and columns into a complete segmented table
+ columns = postprocess.align_columns(columns, table["row_column_bbox"])
+ rows = postprocess.align_rows(rows, table["row_column_bbox"])
+
+ structure["rows"] = rows
+ structure["columns"] = columns
+ structure["column headers"] = column_headers
+ structure["spanning cells"] = spanning_cells
+
+ # Header and spanning-cell refinement needs at least one row and two columns.
+ if len(rows) > 0 and len(columns) > 1:
+ structure = refine_table_structure(structure, class_thresholds)
+
+ table_structures.append(structure)
+
+ return table_structures
+
+
+def refine_table_structure(table_structure, class_thresholds):
+    """
+    Threshold, de-duplicate and align the column headers and spanning cells of a table.
+
+    `table_structure` is updated in place (its "columns", "rows", "spanning cells" and
+    "column headers" entries are rewritten) and also returned.
+    """
+    rows = table_structure["rows"]
+    columns = table_structure["columns"]
+
+    # Column headers: threshold, NMS, then snap to the rows they cover.
+    headers = postprocess.apply_threshold(
+        table_structure["column headers"], class_thresholds["table column header"]
+    )
+    headers = align_headers(postprocess.nms(headers), rows)
+
+    # Spanning cells and projected row headers are thresholded per class, then merged.
+    plain_cells = []
+    row_header_cells = []
+    for elem in table_structure["spanning cells"]:
+        if elem["projected row header"]:
+            row_header_cells.append(elem)
+        else:
+            plain_cells.append(elem)
+    plain_cells = postprocess.apply_threshold(
+        plain_cells, class_thresholds["table spanning cell"]
+    )
+    row_header_cells = postprocess.apply_threshold(
+        row_header_cells, class_thresholds["table projected row header"]
+    )
+    supercells = plain_cells + row_header_cells
+
+    # Align before NMS: alignment brings the cells into agreement with rows and columns,
+    # so any overlap that remains afterwards is a genuine conflict.
+    supercells = postprocess.align_supercells(supercells, rows, columns)
+    supercells = postprocess.nms_supercells(supercells)
+    postprocess.header_supercell_tree(supercells)
+
+    table_structure.update(
+        {
+            "columns": columns,
+            "rows": rows,
+            "spanning cells": supercells,
+            "column headers": headers,
+        }
+    )
+    return table_structure
+
+
def align_headers(headers, rows):
    """
    Adjust the header boundary to be the convex hull of the rows it intersects
    at least 50% of the height of.

    For now, we are not supporting tables with multiple headers, so we need to
    eliminate anything besides the top-most header.

    Args:
        headers: detected header objects; each is a dict with a "bbox" entry
            whose items 1 and 3 are read as the top and bottom y coordinates.
        rows: row objects (dicts with a "bbox" entry). Every row is modified in
            place: its "column header" key is set to True when the row belongs
            to the aligned header and to False otherwise.

    Returns:
        An empty list when no row is covered by a header, otherwise a list
        holding a single ``{"bbox": ...}`` dict that encloses the header rows.
    """
    aligned_headers = []

    for row in rows:
        row["column header"] = False

    # Collect, across all headers, the rows that a header covers by at least
    # half of the row height. A set is used so that a row matched by several
    # overlapping headers is only counted once.
    header_row_nums = set()
    for header in headers:
        header_top = header["bbox"][1]
        header_bottom = header["bbox"][3]
        for row_num, row in enumerate(rows):
            row_top = row["bbox"][1]
            row_bottom = row["bbox"][3]
            row_height = row_bottom - row_top
            if row_height <= 0:
                # Degenerate row: the overlap ratio is undefined (division by
                # zero) or meaningless (negative height), so it cannot match.
                continue
            overlap_height = min(row_bottom, header_bottom) - max(row_top, header_top)
            if overlap_height / row_height >= 0.5:
                header_row_nums.add(row_num)

    if not header_row_nums:
        return aligned_headers

    # Only the top-most header is supported: start at the first matched row
    # and stop at the first gap, which ignores any later rows labeled as a
    # header. Sorting makes this independent of the order of `headers`.
    ordered_row_nums = sorted(header_row_nums)
    last_header_row = ordered_row_nums[0]
    for row_num in ordered_row_nums[1:]:
        if row_num != last_header_row + 1:
            break
        last_header_row = row_num

    # The header always extends up to the first row of the table, so every
    # row from 0 through the end of the contiguous run is a header row.
    header_rect = Rect()
    for row_num in range(last_header_row + 1):
        row = rows[row_num]
        row["column header"] = True
        header_rect = header_rect.include_rect(row["bbox"])

    aligned_headers.append({"bbox": header_rect.get_bbox()})

    return aligned_headers
+
+
def structure_to_cells(table_structure, tokens):
    """
    Assuming the row, column, spanning cell, and header bounding boxes have
    been refined into a set of consistent table structures, process these
    table structures into table cells. This is a universal representation
    format for the table, which can later be exported to Pandas or CSV formats.
    Classify the cells as header/access cells or data cells
    based on if they intersect with the header bounding box.

    Args:
        table_structure: dict with the keys "columns", "rows" and
            "spanning cells"; each value is a list of dicts carrying a "bbox".
            Rows may carry a "column header" flag and spanning cells must
            carry a "projected row header" flag. The "bbox" lists of the rows
            and columns are modified in place near the end of this function.
        tokens: text spans handed to ``postprocess.slot_into_containers``; each
            one is later read through its "bbox" entry.

    Returns:
        A ``(cells, confidence_score)`` tuple. Every cell is a dict with the
        keys "bbox", "column_nums", "row_nums", "column header",
        "projected row header", "cell text" and "spans" (grid cells that are
        not part of a spanning cell additionally keep a "subcell" key set to
        False). ``confidence_score`` is the average of the mean and the minimum
        cell match score, or 0 when there are no scores.
    """
    columns = table_structure["columns"]
    rows = table_structure["rows"]
    spanning_cells = table_structure["spanning cells"]
    cells = []
    subcells = []
    # Identify complete cells and subcells
    for column_num, column in enumerate(columns):
        for row_num, row in enumerate(rows):
            column_rect = Rect(list(column["bbox"]))
            row_rect = Rect(list(row["bbox"]))
            cell_rect = row_rect.intersect(column_rect)
            header = "column header" in row and row["column header"]
            cell = {
                "bbox": cell_rect.get_bbox(),
                "column_nums": [column_num],
                "row_nums": [row_num],
                "column header": header,
            }

            cell["subcell"] = False
            # A grid cell is a subcell when more than half of its area lies
            # inside a spanning cell. A zero-area cell (degenerate row or
            # column box) has no meaningful coverage ratio and would divide by
            # zero, so it is never treated as a subcell.
            if cell_rect.get_area() > 0:
                for spanning_cell in spanning_cells:
                    spanning_cell_rect = Rect(list(spanning_cell["bbox"]))
                    if (
                        spanning_cell_rect.intersect(cell_rect).get_area() / cell_rect.get_area()
                    ) > 0.5:
                        cell["subcell"] = True
                        break

            if cell["subcell"]:
                subcells.append(cell)
            else:
                cell["projected row header"] = False
                cells.append(cell)

    # Merge the subcells covered by each spanning cell into one cell
    for spanning_cell in spanning_cells:
        spanning_cell_rect = Rect(list(spanning_cell["bbox"]))
        cell_columns = set()
        cell_rows = set()
        cell_rect = None
        header = True
        for subcell in subcells:
            subcell_rect = Rect(list(subcell["bbox"]))
            # Read the area before intersecting, as the original code does
            # (presumably because intersect() updates subcell_rect -- confirm).
            subcell_rect_area = subcell_rect.get_area()
            if (subcell_rect.intersect(spanning_cell_rect).get_area() / subcell_rect_area) > 0.5:
                if cell_rect is None:
                    cell_rect = Rect(list(subcell["bbox"]))
                else:
                    cell_rect.include_rect(list(subcell["bbox"]))
                cell_rows = cell_rows.union(set(subcell["row_nums"]))
                cell_columns = cell_columns.union(set(subcell["column_nums"]))
                # By convention here, all subcells must be classified
                # as header cells for a spanning cell to be classified as a header cell;
                # otherwise, this could lead to a non-rectangular header region
                header = header and "column header" in subcell and subcell["column header"]
        if len(cell_rows) > 0 and len(cell_columns) > 0:
            cell = {
                "bbox": cell_rect.get_bbox(),
                "column_nums": list(cell_columns),
                "row_nums": list(cell_rows),
                "column header": header,
                "projected row header": spanning_cell["projected row header"],
            }
            cells.append(cell)

    # Compute a confidence score based on how well the page tokens
    # slot into the cells reported by the model
    _, _, cell_match_scores = postprocess.slot_into_containers(cells, tokens)
    try:
        mean_match_score = sum(cell_match_scores) / len(cell_match_scores)
        min_match_score = min(cell_match_scores)
        confidence_score = (mean_match_score + min_match_score) / 2
    except ZeroDivisionError:
        # No cells were scored
        confidence_score = 0

    # Rows and columns are used as-is here (no gap filling / dilation is
    # applied before the final extraction).
    dilated_columns = columns
    dilated_rows = rows
    for cell in cells:
        column_rect = Rect()
        for column_num in cell["column_nums"]:
            column_rect.include_rect(list(dilated_columns[column_num]["bbox"]))
        row_rect = Rect()
        for row_num in cell["row_nums"]:
            row_rect.include_rect(list(dilated_rows[row_num]["bbox"]))
        cell_rect = column_rect.intersect(row_rect)
        cell["bbox"] = cell_rect.get_bbox()

    span_nums_by_cell, _, _ = postprocess.slot_into_containers(
        cells, tokens, overlap_threshold=0.001, forced_assignment=False
    )

    for cell, cell_span_nums in zip(cells, span_nums_by_cell):
        cell_spans = [tokens[num] for num in cell_span_nums]
        # TODO: Refine how text is extracted; should be character-based, not span-based;
        # but need to associate
        cell["cell text"] = postprocess.extract_text_from_spans(
            cell_spans, remove_integer_superscripts=False
        )
        cell["spans"] = cell_spans

    # Adjust the row, column, and cell bounding boxes to reflect the extracted text
    # NOTE(review): "row_nums"/"column_nums" were assigned against the incoming
    # order of rows/columns; the sorts below presumably leave that order
    # unchanged -- confirm against the caller.
    num_rows = len(rows)
    rows = postprocess.sort_objects_top_to_bottom(rows)
    num_columns = len(columns)
    columns = postprocess.sort_objects_left_to_right(columns)
    min_y_values_by_row = defaultdict(list)
    max_y_values_by_row = defaultdict(list)
    min_x_values_by_column = defaultdict(list)
    max_x_values_by_column = defaultdict(list)
    for cell in cells:
        min_row = min(cell["row_nums"])
        max_row = max(cell["row_nums"])
        min_column = min(cell["column_nums"])
        max_column = max(cell["column_nums"])
        for span in cell["spans"]:
            min_x_values_by_column[min_column].append(span["bbox"][0])
            min_y_values_by_row[min_row].append(span["bbox"][1])
            max_x_values_by_column[max_column].append(span["bbox"][2])
            max_y_values_by_row[max_row].append(span["bbox"][3])
    # Rows: left/right edges follow the text of the first/last column,
    # top/bottom edges follow the text of the row itself.
    for row_num, row in enumerate(rows):
        if len(min_x_values_by_column[0]) > 0:
            row["bbox"][0] = min(min_x_values_by_column[0])
        if len(min_y_values_by_row[row_num]) > 0:
            row["bbox"][1] = min(min_y_values_by_row[row_num])
        if len(max_x_values_by_column[num_columns - 1]) > 0:
            row["bbox"][2] = max(max_x_values_by_column[num_columns - 1])
        if len(max_y_values_by_row[row_num]) > 0:
            row["bbox"][3] = max(max_y_values_by_row[row_num])
    # Columns: left/right edges follow the text of the column itself,
    # top/bottom edges follow the text of the first/last row.
    for column_num, column in enumerate(columns):
        if len(min_x_values_by_column[column_num]) > 0:
            column["bbox"][0] = min(min_x_values_by_column[column_num])
        if len(min_y_values_by_row[0]) > 0:
            column["bbox"][1] = min(min_y_values_by_row[0])
        if len(max_x_values_by_column[column_num]) > 0:
            column["bbox"][2] = max(max_x_values_by_column[column_num])
        if len(max_y_values_by_row[num_rows - 1]) > 0:
            column["bbox"][3] = max(max_y_values_by_row[num_rows - 1])
    # Recompute each cell box from the adjusted rows and columns; a cell keeps
    # its previous box when the adjusted boxes no longer intersect.
    for cell in cells:
        row_rect = None
        column_rect = None
        for row_num in cell["row_nums"]:
            if row_rect is None:
                row_rect = Rect(list(rows[row_num]["bbox"]))
            else:
                row_rect.include_rect(list(rows[row_num]["bbox"]))
        for column_num in cell["column_nums"]:
            if column_rect is None:
                column_rect = Rect(list(columns[column_num]["bbox"]))
            else:
                column_rect.include_rect(list(columns[column_num]["bbox"]))
        cell_rect = row_rect.intersect(column_rect)
        if cell_rect.get_area() > 0:
            cell["bbox"] = cell_rect.get_bbox()

    return cells, confidence_score
+
+
def cells_to_html(cells):
    """Convert table structure to html format.

    Cells are emitted in reading order (by first row, then by first column).
    Each time a cell starts on a later row than any seen so far, a new
    container is opened under ``<table>``: ``<thead>`` holding ``<th>`` cells
    when that first cell is a column header, ``<tr>`` holding ``<td>`` cells
    otherwise. Cells spanning several columns/rows get ``colspan``/``rowspan``
    attributes. The serialized ``<table>`` element is returned as a string.
    """
    # One stable sort on a (row, column) key yields the same order as sorting
    # by column and then stably re-sorting by row.
    ordered_cells = sorted(
        cells, key=lambda item: (min(item["row_nums"]), min(item["column_nums"]))
    )

    table = ET.Element("table")
    latest_row = -1

    for cell in ordered_cells:
        first_row = min(cell["row_nums"])

        # Only spans wider than one column/row are written out.
        span_attrs = {
            attr_name: str(len(cell[nums_key]))
            for attr_name, nums_key in (("colspan", "column_nums"), ("rowspan", "row_nums"))
            if len(cell[nums_key]) > 1
        }

        if first_row > latest_row:
            latest_row = first_row
            # The first cell of a row decides the tags used for the whole row.
            is_header_row = cell["column header"]
            cell_tag = "th" if is_header_row else "td"
            container = ET.SubElement(table, "thead" if is_header_row else "tr")

        node = ET.SubElement(container, cell_tag, attrib=span_attrs)
        node.text = cell["cell text"]

    return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))