diff --git a/.gitignore b/.gitignore index 6de83ed8..4c6c70d9 100755 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ __pycache__/ .idea/ openapi.json openapi_client.json +.env diff --git a/_test_unstructured_client/integration/test_decorators.py b/_test_unstructured_client/integration/test_decorators.py index e1cc73e0..732df1a9 100644 --- a/_test_unstructured_client/integration/test_decorators.py +++ b/_test_unstructured_client/integration/test_decorators.py @@ -5,6 +5,7 @@ import httpx import json +import os import pytest import requests from deepdiff import DeepDiff @@ -19,7 +20,7 @@ from unstructured_client._hooks.custom import form_utils from unstructured_client._hooks.custom import split_pdf_hook -FAKE_KEY = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" +FAKE_KEY = os.getenv("UNSTRUCTURED_API_KEY") or "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" @pytest.mark.parametrize("concurrency_level", [1, 2, 5]) @@ -472,3 +473,44 @@ async def mock_send(_, request: httpx.Request, **kwargs): assert mock_endpoint_called assert res.status_code == 200 + + +@pytest.mark.parametrize( + ("filename", "chunking_strategy", "expected_elements_num"), + [ + # -- Paid strategy -- + ("_sample_docs/layout-parser-paper.pdf", "by_page", 16), # 16 pages, 133 elements w/o chunking + ("_sample_docs/layout-parser-paper.pdf", shared.ChunkingStrategy.BY_PAGE, 16), + # -- Open source strategy -- + ("_sample_docs/layout-parser-paper.pdf", "by_title", -1), # unsure what the correct number is atm + ("_sample_docs/layout-parser-paper.pdf", shared.ChunkingStrategy.BY_TITLE, -1), + ], +) +def test_chunking( + filename: str, + chunking_strategy: str| shared.ChunkingStrategy, + expected_elements_num: int, +): + + client = UnstructuredClient(api_key_auth=FAKE_KEY) + + with open(filename, "rb") as f: + files = shared.Files( + content=f.read(), + file_name=filename, + ) + + parameters = shared.PartitionParameters( + files=files, + chunking_strategy=chunking_strategy, # type: ignore + split_pdf_page=False, # -- Testing splitting as potential issue + ) + + req = operations.PartitionRequest( + partition_parameters=parameters + ) + + resp = client.general.partition(request=req) + assert len(resp.elements) == expected_elements_num + assert all(element.get("type") == "CompositeElement" for element in resp.elements) +