diff --git a/providers/common/ai/docs/operators/document_loader.rst b/providers/common/ai/docs/operators/document_loader.rst new file mode 100644 index 0000000000000..8a836c37d120e --- /dev/null +++ b/providers/common/ai/docs/operators/document_loader.rst @@ -0,0 +1,297 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +.. _howto/operator:document_loader: + +``DocumentLoaderOperator`` +========================== + +Use :class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` +to parse files into ``list[dict(text, metadata)]`` for downstream embedding +pipelines. The operator bridges Airflow's connectivity layer (hooks that +produce bytes or local files) and the AI embedding layer (operators that +need structured text with metadata). + +The operator is **framework-agnostic** -- it has no dependency on LlamaIndex, +LangChain, or any other AI framework. + +Basic usage +----------- + +``.txt``, ``.md``, ``.csv``, and ``.json`` are handled with zero extra +dependencies: + +.. 
exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py + :language: python + :start-after: [START howto_operator_document_loader_basic] + :end-before: [END howto_operator_document_loader_basic] + +CSV files produce one document per row, with empty cells skipped. JSON files +with a top-level array produce one document per element; a single JSON object +produces one document. By default each dict is flattened into ``"key: value, +key: value"`` text so the embedding sees content tokens rather than JSON +syntax (see the ``json_text_field`` section below for the structured variant). + +PDF parsing +----------- + +Install the ``pdf`` extra to parse PDF files via +`pypdf <https://pypi.org/project/pypdf/>`__:: + + pip install apache-airflow-providers-common-ai[pdf] + +Each page with extractable text becomes a separate document. Empty pages are +skipped. ``page_number`` is included in the document metadata. + +DOCX parsing +------------ + +Install the ``docx`` extra to parse Word documents via +`python-docx <https://pypi.org/project/python-docx/>`__:: + + pip install apache-airflow-providers-common-ai[docx] + +All non-empty paragraphs are concatenated into a single document per file. + +.. note:: + + DOCX extraction reads paragraph text only. Tables, headers, footers, and + footnotes are not included. For richer DOCX parsing, use a dedicated + extraction tool (``Unstructured``, ``docling``) as a custom parser + backend. + +Directory mode and filtering +---------------------------- + +Point ``source_path`` at a directory or pass a glob pattern (``**`` enables +recursive matching). Combine with ``file_extensions`` to scope which files +are processed: + +.. 
exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py + :language: python + :start-after: [START howto_operator_document_loader_directory] + :end-before: [END howto_operator_document_loader_directory] + +Directory-mode behavior when ``file_extensions`` is omitted: + +- Files whose name starts with a ``.`` (``.DS_Store``, editor swap files, + ``.gitkeep``, ...) are silently ignored. +- Files whose extension is not in the built-in dispatch map are skipped + with a warning rather than crashing the operator. A glob pattern that + matches an unknown extension is treated as intentional: set ``parser`` + explicitly, or the default ``"auto"`` raises ``ValueError`` for that file. + +Loading from bytes +------------------ + +When upstream tasks produce file content as bytes (S3, GCS, HTTP, etc.), +pass them via ``source_bytes`` and tell the operator how to interpret them +with ``file_type``. ``source_bytes`` is not a template field because Jinja +would render ``bytes`` as their ``repr`` text, which would break binary +parsing: + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py + :language: python + :start-after: [START howto_operator_document_loader_bytes] + :end-before: [END howto_operator_document_loader_bytes] + +PDF and DOCX bytes are parsed via an in-memory stream -- no temporary files +on disk. + +Structured JSON ingestion +------------------------- + +For arrays of records where one field is the body and the rest are metadata +(article ingestion, ticket exports, ...), set ``json_text_field`` to the key +that holds the text. Every other key on the same item lands in ``metadata``: + +.. 
exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py + :language: python + :start-after: [START howto_operator_document_loader_json_field] + :end-before: [END howto_operator_document_loader_json_field] + +For **arbitrary API data** (Salesforce SOQL results, database query exports), +a ``@task`` that maps fields to text and metadata is still appropriate when +the field shape is more complex than what ``json_text_field`` covers: + +.. code-block:: python + + @task + def transform_cases(records: list[dict]) -> list[dict]: + return [ + { + "text": f"{r['Subject']}\n\n{r['Description']}", + "metadata": {"case_id": r["Id"], "source": "salesforce"}, + } + for r in records + ] + +No chunking +----------- + +The operator parses files into documents; it does **not** split them into +fixed-size chunks. The right chunking strategy depends on the embedding +model and is intentionally left to a downstream text-splitter or embedding +operator (LlamaIndex's ``EmbeddingOperator``, LangChain's text splitters, +...). + +Format coverage roadmap +----------------------- + +The current built-in dispatch covers ``.txt``, ``.md``, ``.csv``, ``.json``, +``.pdf``, ``.docx``. Additional formats are deferred to follow-ups, each +gated behind its own extra so users only install what they need: + +- ``.pptx`` via ``python-pptx`` +- ``.epub`` via ``ebooklib`` +- ``.xlsx`` via ``openpyxl`` +- ``.html`` / ``.htm`` via ``beautifulsoup4`` +- Image OCR (``.png`` / ``.jpg``) via ``pytesseract`` +- Audio transcription via a model call (``LLMOperator`` or ``AgentOperator`` + is a better fit for transcription than this parser) + +For anything not in the dispatch map, set ``parser`` explicitly (``"text"`` +to read as plain text) or write the parser inline in a ``@task`` that calls +``DocumentLoaderOperator`` with ``source_bytes`` for known formats. 
+ +Composing with downstream embedding operators +--------------------------------------------- + +The output format (``list[dict(text, metadata)]``) is designed to feed +directly into embedding operators. With LlamaIndex's ``EmbeddingOperator``: + +.. code-block:: python + + load = DocumentLoaderOperator( + task_id="load", + source_path="/data/docs/*.pdf", + ) + + embed = EmbeddingOperator( + task_id="embed", + documents="{{ ti.xcom_pull(task_ids='load') }}", + llm_conn_id="openai_default", + ) + + load >> embed + +Cloud storage URIs +------------------ + +``source_path`` accepts any URI that +:class:`~airflow.sdk.ObjectStoragePath` resolves via fsspec +(``s3://``, ``gs://``, ``azure://``, ``file://``, ...). Point it at a +single object or a directory; cross-directory globs in cloud URIs are not +supported in this version. + +.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py + :language: python + :start-after: [START howto_operator_document_loader_cloud_uri] + :end-before: [END howto_operator_document_loader_cloud_uri] + +Use ``source_conn_id`` to point at the Airflow connection that holds the +cloud credentials (``aws_default``, ``google_cloud_default``, ...). For +single-file URIs, ``source_conn_id`` works the same way. + +If you'd rather download the file with a dedicated provider operator +first (e.g. to get retry semantics specific to that storage), the +download-then-parse pattern still works: + +.. 
code-block:: python + + from airflow.providers.amazon.aws.transfers.s3_to_local import S3ToLocalFilesystemOperator + + download = S3ToLocalFilesystemOperator( + task_id="download", + bucket_name="my-bucket", + key="documents/report.pdf", + local_path="/tmp/report.pdf", + ) + + load = DocumentLoaderOperator( + task_id="load", + source_path="/tmp/report.pdf", + ) + + download >> load + +Non-UTF-8 inputs +---------------- + +The text parsers (``.txt`` / ``.md`` / ``.csv`` / ``.json``) and the bytes +path default to UTF-8. To handle Windows-1252 CSVs, files with a leading +UTF-8 byte-order mark (``utf-8-sig``), or any other encoding, set the ``encoding`` +parameter on the operator (and optionally ``encoding_errors="replace"`` to +tolerate mixed-encoding sources at the cost of some character loss). A +failed decode includes the offending file path in the error so +directory-mode runs are easy to diagnose. + +Metadata precedence +------------------- + +Auto-extracted metadata keys -- ``file_name``, ``file_path``, ``row_index``, +``item_index``, ``page_number`` -- take precedence over keys with the same +name in ``metadata_fields``. ``metadata_fields`` fills gaps; it never +overwrites the auto-extracted shape. + +Parameters +---------- + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Parameter + - Description + * - ``source_path`` + - Local file, directory, or glob pattern, **or** a storage URI + (``s3://``, ``gs://``, ``azure://``, ``file://``) resolved via + :class:`~airflow.sdk.ObjectStoragePath`. ``**`` is recursive for + local globs; cross-directory globs in cloud URIs are not supported. + Mutually exclusive with ``source_bytes``. + * - ``source_conn_id`` + - Airflow connection ID for the cloud-storage credentials used by + ``ObjectStoragePath`` (``aws_default``, ``google_cloud_default``, + ...). Ignored for local paths. + * - ``source_bytes`` + - Raw file bytes from XCom. Requires ``file_type``. Mutually exclusive + with ``source_path``. 
Not a template field (bytes don't survive Jinja). + * - ``file_type`` + - File extension hint (e.g. ``".pdf"``). Required with ``source_bytes``; + optional with ``source_path`` to override auto-detection. + * - ``parser`` + - Parsing backend. ``"auto"`` (default) picks from the file extension. + Set explicitly to force a backend (e.g. ``"text"`` to treat an + unknown extension as plain text). + * - ``file_extensions`` + - Filter for ``source_path`` directory or glob. When omitted in + directory mode, files whose name starts with a ``.`` are ignored + and unknown-extension files are skipped with a warning. + * - ``metadata_fields`` + - Extra key-value pairs merged into every document's metadata. Does + not override auto-extracted keys. + * - ``encoding`` + - Text encoding for the bytes path and ``.txt`` / ``.md`` / ``.csv`` / + ``.json`` files. Defaults to ``"utf-8"``. + * - ``encoding_errors`` + - How decode errors are handled (``"strict"`` / ``"replace"`` / + ``"ignore"``). Defaults to ``"strict"``. + * - ``json_text_field`` + - When parsing JSON, treat this key as the embedding text; every other + key on the same item lands in ``metadata``. When unset, dicts are + flattened to ``"k: v, k: v"`` so the embedding sees content tokens + rather than JSON syntax. diff --git a/providers/common/ai/docs/operators/index.rst b/providers/common/ai/docs/operators/index.rst index 89ba5d15e6c20..dec108990eee2 100644 --- a/providers/common/ai/docs/operators/index.rst +++ b/providers/common/ai/docs/operators/index.rst @@ -21,7 +21,7 @@ Common AI Operators Choosing the right operator --------------------------- -The common-ai provider ships five operators (and matching ``@task`` decorators). Use this table +The common-ai provider ships several operators (and matching ``@task`` decorators). Use this table to pick the one that fits your use case: .. 
list-table:: @@ -46,6 +46,9 @@ to pick the one that fits your use case: * - Multi-turn reasoning with tools (DB queries, API calls, etc.) - :class:`~airflow.providers.common.ai.operators.agent.AgentOperator` - ``@task.agent`` + * - Parse files (PDF, DOCX, CSV, etc.) into document dicts for embedding + - :class:`~airflow.providers.common.ai.operators.document_loader.DocumentLoaderOperator` + - *(no decorator)* **LLMOperator / @task.llm** — stateless, single-turn calls. Use this for classification, summarization, extraction, or any prompt that produces one response. Supports structured output @@ -63,6 +66,10 @@ read files) to produce its answer. You configure available tools through ``tools AgentOperator *works* without toolsets — pydantic-ai supports tool-less agents for multi-turn reasoning — but if you don't need tools, ``LLMOperator`` is simpler and more explicit. +**DocumentLoaderOperator** — framework-agnostic file parsing. Use this to convert files +(text, CSV, JSON, PDF, DOCX) into ``list[dict(text, metadata)]`` for downstream embedding. +No AI framework dependency. 
+ Operator guides --------------- diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml index e56dcce6cd6db..92d826ffb0a89 100644 --- a/providers/common/ai/provider.yaml +++ b/providers/common/ai/provider.yaml @@ -42,6 +42,7 @@ integrations: - /docs/apache-airflow-providers-common-ai/operators/llm_branch.rst - /docs/apache-airflow-providers-common-ai/operators/llm_sql.rst - /docs/apache-airflow-providers-common-ai/operators/llm_schema_compare.rst + - /docs/apache-airflow-providers-common-ai/operators/document_loader.rst tags: [ai] - integration-name: Pydantic AI external-doc-url: https://ai.pydantic.dev/ @@ -363,6 +364,7 @@ operators: - airflow.providers.common.ai.operators.llm_branch - airflow.providers.common.ai.operators.llm_sql - airflow.providers.common.ai.operators.llm_schema_compare + - airflow.providers.common.ai.operators.document_loader task-decorators: - class-name: airflow.providers.common.ai.decorators.agent.agent_task diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml index da9fe3236252c..18833cd64bc1a 100644 --- a/providers/common/ai/pyproject.toml +++ b/providers/common/ai/pyproject.toml @@ -98,6 +98,8 @@ dependencies = [ "langchain" = [ "langchain>=1.0.0", ] +"pdf" = ["pypdf>=4.0.0"] +"docx" = ["python-docx>=1.0.0"] [dependency-groups] dev = [ diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py new file mode 100644 index 0000000000000..aa80c77d32d12 --- /dev/null +++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_document_loader.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAGs demonstrating DocumentLoaderOperator usage patterns. + +Each DAG covers a single pattern. The hook docs reference these via +``.. exampleinclude::`` so the runnable snippets stay in sync. +""" + +from __future__ import annotations + +from airflow.providers.common.ai.operators.document_loader import DocumentLoaderOperator +from airflow.providers.common.compat.sdk import dag, task + + +# [START howto_operator_document_loader_basic] +@dag(schedule=None) +def example_document_loader_basic(): + """Parse a single local file -- the operator infers the format from the suffix.""" + + load_docs = DocumentLoaderOperator( + task_id="load_docs", + source_path="/opt/airflow/data/articles/sample.md", + ) + + @task + def count_chunks(docs: list[dict]) -> int: + return len(docs) + + count_chunks(load_docs.output) + + +# [END howto_operator_document_loader_basic] + +example_document_loader_basic() + + +# [START howto_operator_document_loader_directory] +@dag(schedule=None) +def example_document_loader_directory(): + """Walk a directory recursively, only picking up PDFs and Markdown.""" + + load_docs = DocumentLoaderOperator( + task_id="load_docs", + # `**` matches across subdirectories thanks to glob's recursive mode. 
+ source_path="/opt/airflow/data/library/**/*", + file_extensions=[".pdf", ".md"], + metadata_fields={"corpus": "library_v3"}, + ) + + @task + def summarise(docs: list[dict]) -> dict: + return { + "files": len({d["metadata"]["file_path"] for d in docs}), + "chunks": len(docs), + } + + summarise(load_docs.output) + + +# [END howto_operator_document_loader_directory] + +example_document_loader_directory() + + +# [START howto_operator_document_loader_bytes] +@dag(schedule=None) +def example_document_loader_bytes(): + """Feed raw bytes from an upstream hook (e.g. an S3 download) into the parser.""" + + @task + def fetch_pdf_bytes() -> bytes: + # In real use this would be an S3Hook.read_key, a GCSHook.download_as_bytes, + # or any other byte-producing call. + return b"%PDF-1.4 ..." + + load_docs = DocumentLoaderOperator( + task_id="load_docs", + source_bytes=fetch_pdf_bytes(), + file_type=".pdf", + metadata_fields={"corpus": "uploads"}, + ) + + load_docs + + +# [END howto_operator_document_loader_bytes] + +example_document_loader_bytes() + + +# [START howto_operator_document_loader_json_field] +@dag(schedule=None) +def example_document_loader_json_field(): + """Read an array of records, embedding only the ``body`` field per item. + + Every other key (``title``, ``author``, ``published_at``, ...) lands in + ``metadata`` so it stays available for filtering or display. 
+ """ + + load_docs = DocumentLoaderOperator( + task_id="load_docs", + source_path="/opt/airflow/data/articles.json", + json_text_field="body", + ) + + load_docs + + +# [END howto_operator_document_loader_json_field] + +example_document_loader_json_field() + + +# [START howto_operator_document_loader_cloud_uri] +@dag(schedule=None) +def example_document_loader_cloud_uri(): + """Read PDFs directly from S3 -- no separate download step.""" + + load_docs = DocumentLoaderOperator( + task_id="load_docs", + source_path="s3://my-bucket/reports/", + source_conn_id="aws_default", + file_extensions=[".pdf"], + ) + + load_docs + + +# [END howto_operator_document_loader_cloud_uri] + +example_document_loader_cloud_uri() diff --git a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py index 7cb8513f9e606..d87733bb5ffa0 100644 --- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py +++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py @@ -37,6 +37,7 @@ def get_provider_info(): "/docs/apache-airflow-providers-common-ai/operators/llm_branch.rst", "/docs/apache-airflow-providers-common-ai/operators/llm_sql.rst", "/docs/apache-airflow-providers-common-ai/operators/llm_schema_compare.rst", + "/docs/apache-airflow-providers-common-ai/operators/document_loader.rst", ], "tags": ["ai"], }, @@ -298,6 +299,7 @@ def get_provider_info(): "airflow.providers.common.ai.operators.llm_branch", "airflow.providers.common.ai.operators.llm_sql", "airflow.providers.common.ai.operators.llm_schema_compare", + "airflow.providers.common.ai.operators.document_loader", ], } ], diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/document_loader.py b/providers/common/ai/src/airflow/providers/common/ai/operators/document_loader.py new file mode 100644 index 0000000000000..7f563d0aa4adc --- /dev/null +++ 
b/providers/common/ai/src/airflow/providers/common/ai/operators/document_loader.py @@ -0,0 +1,401 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import csv +import glob +import io +import json +from collections.abc import Sequence +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO + +from airflow.providers.common.compat.sdk import ( + AirflowOptionalProviderFeatureException, + BaseOperator, +) + +if TYPE_CHECKING: + from airflow.sdk import Context + + +# Type alias for path-like inputs that the parsers can read from. ``Path`` is +# the local filesystem; ``ObjectStoragePath`` covers ``s3://``, ``gs://``, +# ``azure://``, ``file://``, ... via fsspec. Both expose the methods we need +# (``read_bytes``, ``open``, ``name``, ``suffix``) so the parsers stay +# polymorphic. +FilePathT = Any # Path | ObjectStoragePath + + +class DocumentLoaderOperator(BaseOperator): + """ + Parse files into ``list[dict(text, metadata)]`` for downstream embedding. + + Bridges Airflow's connectivity layer (hooks that produce bytes or local + files) and the AI embedding layer (operators that need structured text + with metadata). 
Framework-agnostic: no LlamaIndex, LangChain, or other + AI framework dependency. + + Built-in parsers handle ``.txt``, ``.md``, ``.csv``, and ``.json`` with + zero extra dependencies. PDF and DOCX support require optional packages + installable via extras:: + + pip install apache-airflow-providers-common-ai[pdf] # pypdf + pip install apache-airflow-providers-common-ai[docx] # python-docx + + Provide exactly one of ``source_path`` or ``source_bytes``. When using + ``source_bytes``, ``file_type`` is required so the operator knows which + parser to use. + + The operator is intentionally a **loader**: it does not split documents + into fixed-size chunks. Pass the output to a downstream text-splitter or + embedding operator if you need chunking. + + :param source_path: A local path, glob pattern, or storage URI + (``s3://``, ``gs://``, ``azure://``, ``file://``, ...). Cloud URIs + go through :class:`~airflow.sdk.ObjectStoragePath` / fsspec. + ``**`` enables recursive matching for local globs. Cloud URIs + accept a single file or a directory; cross-directory globs in a + cloud URI are not supported in this version. + :param source_conn_id: Airflow connection ID used by + ``ObjectStoragePath`` for cloud URIs (``aws_default``, + ``google_cloud_default``, ...). Ignored for local paths. + :param source_bytes: Raw file bytes, typically from XCom. + :param file_type: File extension hint when using ``source_bytes`` + (e.g. ``".pdf"``). Also accepted with ``source_path`` to override + auto-detection. + :param parser: Parsing backend selection. ``"auto"`` (default) picks the + backend from the file extension. + :param file_extensions: When ``source_path`` is a directory or glob, + only process files whose extension is in this list. When omitted, + the operator processes only files whose extension is known to the + built-in dispatch (others are skipped with a warning) and silently + ignores files whose name starts with a dot. 
+ :param metadata_fields: Extra key-value pairs merged into every + document's ``metadata`` dict. Auto-extracted fields such as + ``file_name``, ``file_path``, ``row_index``, ``item_index``, and + ``page_number`` take precedence over keys with the same name. + :param encoding: Text encoding used for ``.txt``/``.md``/``.csv``/``.json`` + and for the bytes path. Defaults to ``"utf-8"``. + :param encoding_errors: How decode errors are handled. Defaults to + ``"strict"``; set to ``"replace"`` or ``"ignore"`` to tolerate + mixed-encoding inputs at the cost of some character loss. + :param json_text_field: When parsing JSON, treat this key as the + embedding text and put every other key into ``metadata``. Applies + to each item when the top-level JSON is a list, or to the object + when it is a single dict. When ``None`` (default), the operator + flattens dicts into ``"k: v, k: v"`` text (same shape as the CSV + parser). + """ + + template_fields: Sequence[str] = ( + "source_path", + "source_conn_id", + "file_type", + "file_extensions", + "parser", + "metadata_fields", + ) + + EXTENSION_BACKEND_MAP: dict[str, str] = { + ".txt": "text", + ".md": "text", + ".csv": "csv", + ".json": "json", + ".pdf": "pypdf", + ".docx": "python-docx", + } + + def __init__( + self, + *, + source_path: str | None = None, + source_conn_id: str | None = None, + source_bytes: bytes | None = None, + file_type: str | None = None, + parser: str = "auto", + file_extensions: list[str] | None = None, + metadata_fields: dict[str, Any] | None = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + json_text_field: str | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + if source_path is not None and source_bytes is not None: + raise ValueError("Provide exactly one of 'source_path' or 'source_bytes', not both.") + if source_path is None and source_bytes is None: + raise ValueError("Provide exactly one of 'source_path' or 'source_bytes'.") + if source_bytes is not 
None and file_type is None: + raise ValueError("'file_type' is required when using 'source_bytes' (e.g. '.pdf').") + + self.source_path = source_path + self.source_conn_id = source_conn_id + self.source_bytes = source_bytes + self.file_type = file_type + self.parser = parser + self.file_extensions = file_extensions + self.metadata_fields = metadata_fields + self.encoding = encoding + self.encoding_errors = encoding_errors + self.json_text_field = json_text_field + + def execute(self, context: Context) -> list[dict[str, Any]]: + if self.source_bytes is not None: + assert self.file_type is not None # noqa: S101 -- enforced in __init__ + documents = self._parse_bytes(self.source_bytes, self.file_type) + file_count = 1 + else: + assert self.source_path is not None # noqa: S101 -- enforced in __init__ + files = self._resolve_files(self.source_path) + if not files: + raise FileNotFoundError(f"No files found matching '{self.source_path}'.") + file_count = len(files) + documents = [] + for file_path in files: + ext = self.file_type or file_path.suffix.lower() + parsed = self._parse_file(file_path, ext) + for doc in parsed: + doc["metadata"]["file_name"] = file_path.name + doc["metadata"]["file_path"] = str(file_path) + documents.extend(parsed) + + if self.metadata_fields: + for doc in documents: + # Auto-extracted keys (file_name, page_number, ...) take precedence. + for key, value in self.metadata_fields.items(): + doc["metadata"].setdefault(key, value) + + self.log.info("Parsed %d documents from %d file(s)", len(documents), file_count) + return documents + + def _resolve_files(self, source_path: str) -> list[FilePathT]: + # A storage URI (``s3://``, ``gs://``, ``file://``, ...) goes through + # ObjectStoragePath / fsspec; a bare local path keeps the existing + # glob behaviour. The heuristic is intentionally simple: presence of + # ``://`` indicates a URI. 
+ if "://" in source_path: + return self._resolve_remote_files(source_path) + return self._resolve_local_files(source_path) + + def _resolve_local_files(self, source_path: str) -> list[Path]: + path = Path(source_path) + if path.is_file(): + return [path] + + if path.is_dir(): + candidates = sorted(p for p in path.iterdir() if not p.name.startswith(".")) + is_directory_mode = True + else: + # `recursive=True` makes `**` match across directories per the docstring. + candidates = [Path(p) for p in sorted(glob.glob(source_path, recursive=True))] + is_directory_mode = False + + return self._filter_files([p for p in candidates if p.is_file()], is_directory_mode=is_directory_mode) + + def _resolve_remote_files(self, source_path: str) -> list[FilePathT]: + from airflow.sdk import ObjectStoragePath + + root = ObjectStoragePath(source_path, conn_id=self.source_conn_id) + try: + if root.is_file(): + return [root] + except FileNotFoundError: + # Some fsspec backends raise instead of returning False. + pass + + if not root.is_dir(): + raise FileNotFoundError( + f"Cloud URI '{source_path}' is neither a file nor a directory. " + "Cross-directory globs in cloud URIs aren't supported here; " + "point ``source_path`` at a single object or a directory." + ) + + candidates = sorted( + (p for p in root.iterdir() if not p.name.startswith(".")), + key=str, + ) + return self._filter_files([p for p in candidates if p.is_file()], is_directory_mode=True) + + def _filter_files(self, results: list[FilePathT], *, is_directory_mode: bool) -> list[FilePathT]: + if self.file_extensions: + allowed = {(ext if ext.startswith(".") else f".{ext}").lower() for ext in self.file_extensions} + return [p for p in results if p.suffix.lower() in allowed] + + if is_directory_mode: + # No explicit filter in directory mode: skip files we don't know + # how to parse rather than crashing on the first stray file + # (``.DS_Store``, editor swap files, etc.). 
A glob is treated as + # intentional and parsed via the explicit ``parser`` argument. + known = set(self.EXTENSION_BACKEND_MAP.keys()) + unknown = [p for p in results if p.suffix.lower() not in known] + if unknown: + self.log.warning( + "Skipping %d file(s) with unrecognised extension: %s", + len(unknown), + ", ".join(sorted({p.suffix or "" for p in unknown})), + ) + return [p for p in results if p.suffix.lower() in known] + + return results + + def _parse_bytes(self, raw: bytes, file_type: str) -> list[dict[str, Any]]: + ext = file_type if file_type.startswith(".") else f".{file_type}" + backend = self._resolve_backend(ext) + + if backend == "pypdf": + return self._parse_pdf_stream(io.BytesIO(raw)) + if backend == "python-docx": + return self._parse_docx_stream(io.BytesIO(raw)) + + text = self._decode(raw, source_hint=f"") + if backend == "csv": + return self._parse_csv_text(text) + if backend == "json": + return self._parse_json_text(text) + return [{"text": text, "metadata": {}}] + + def _parse_file(self, file_path: Path, ext: str) -> list[dict[str, Any]]: + backend = self._resolve_backend(ext) + + if backend == "text": + return self._parse_text(file_path) + if backend == "csv": + return self._parse_csv(file_path) + if backend == "json": + return self._parse_json(file_path) + if backend == "pypdf": + with file_path.open("rb") as fh: + return self._parse_pdf_stream(fh) + if backend == "python-docx": + with file_path.open("rb") as fh: + return self._parse_docx_stream(fh) + + raise ValueError(f"No parser found for backend '{backend}'.") + + def _resolve_backend(self, ext: str) -> str: + if self.parser != "auto": + return self.parser + + ext = ext.lower() + if ext not in self.EXTENSION_BACKEND_MAP: + supported = ", ".join(sorted(self.EXTENSION_BACKEND_MAP.keys())) + raise ValueError( + f"No parser registered for extension '{ext}'. " + f"Supported extensions: {supported}. " + f"Set 'parser' explicitly to override auto-detection." 
+ ) + return self.EXTENSION_BACKEND_MAP[ext] + + def _decode(self, raw: bytes, *, source_hint: str) -> str: + try: + return raw.decode(self.encoding, errors=self.encoding_errors) + except UnicodeDecodeError as e: + raise ValueError( + f"Failed to decode {source_hint} as {self.encoding!r}: {e}. " + f"Pass encoding=... or encoding_errors='replace' to tolerate this." + ) from e + + def _read_text(self, file_path: Path) -> str: + return self._decode(file_path.read_bytes(), source_hint=str(file_path)) + + def _parse_text(self, file_path: Path) -> list[dict[str, Any]]: + return [{"text": self._read_text(file_path), "metadata": {}}] + + def _parse_csv(self, file_path: Path) -> list[dict[str, Any]]: + return self._parse_csv_text(self._read_text(file_path)) + + def _parse_csv_text(self, text: str) -> list[dict[str, Any]]: + reader = csv.DictReader(io.StringIO(text)) + documents = [] + for row_idx, row in enumerate(reader): + # Skip empty cells to avoid noisy "col: ," in the text. + row_text = ", ".join(f"{k}: {v}" for k, v in row.items() if v != "") + documents.append({"text": row_text, "metadata": {"row_index": row_idx}}) + return documents + + def _parse_json(self, file_path: Path) -> list[dict[str, Any]]: + return self._parse_json_text(self._read_text(file_path)) + + def _parse_json_text(self, text: str) -> list[dict[str, Any]]: + data = json.loads(text) + if isinstance(data, list): + return [self._json_item_to_doc(item, item_index=idx) for idx, item in enumerate(data)] + return [self._json_item_to_doc(data, item_index=None)] + + def _json_item_to_doc(self, item: Any, *, item_index: int | None) -> dict[str, Any]: + metadata: dict[str, Any] = {} + if item_index is not None: + metadata["item_index"] = item_index + + if isinstance(item, str): + text = item + elif isinstance(item, dict): + if self.json_text_field is not None: + # Pull the named field out as the text; everything else goes + # to metadata. 
Common pattern for "ingest article body, keep + # title/author/url for filtering". + text_value = item.get(self.json_text_field, "") + text = ( + text_value if isinstance(text_value, str) else json.dumps(text_value, ensure_ascii=False) + ) + for k, v in item.items(): + if k == self.json_text_field: + continue + metadata[k] = v + else: + # No text field declared: flatten dict to "k: v, k: v" so the + # embedding sees content tokens, not JSON syntax. Mirrors the + # CSV parser's behaviour. + text = ", ".join(f"{k}: {v}" for k, v in item.items() if v not in (None, "")) + else: + text = json.dumps(item, ensure_ascii=False) + + return {"text": text, "metadata": metadata} + + def _parse_pdf_stream(self, stream: BinaryIO) -> list[dict[str, Any]]: + try: + from pypdf import PdfReader + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) + + reader = PdfReader(stream) + documents = [] + for page_num, page in enumerate(reader.pages): + text = page.extract_text() or "" + if text.strip(): + documents.append({"text": text, "metadata": {"page_number": page_num + 1}}) + return documents + + def _parse_docx_stream(self, stream: BinaryIO) -> list[dict[str, Any]]: + """ + Parse a DOCX stream into documents. + + Extracts paragraph text only. Tables, headers, footers, and footnotes + are not included. For richer DOCX parsing, plug in a dedicated + extraction tool (``Unstructured``, ``docling``) as a custom parser + backend. 
+ """ + try: + from docx import Document + except ImportError as e: + raise AirflowOptionalProviderFeatureException(e) + + doc = Document(stream) + paragraphs = [p.text for p in doc.paragraphs if p.text.strip()] + text = "\n\n".join(paragraphs) + return [{"text": text, "metadata": {}}] diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_document_loader.py b/providers/common/ai/tests/unit/common/ai/operators/test_document_loader.py new file mode 100644 index 0000000000000..bb9db83347cf3 --- /dev/null +++ b/providers/common/ai/tests/unit/common/ai/operators/test_document_loader.py @@ -0,0 +1,596 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.providers.common.ai.operators.document_loader import DocumentLoaderOperator + + +class TestDocumentLoaderInit: + def test_template_fields_render_source_path_and_metadata(self): + """ + Behavioral check that the templated fields actually get rendered. + Replaces the previous tautological assertion that just round-tripped + the class attribute. 
+ """ + op = DocumentLoaderOperator( + task_id="test", + source_path="/data/{{ ds }}/*.pdf", + file_type="{{ var.value.preferred_ext }}", + metadata_fields={"run_id": "{{ run_id }}"}, + ) + # Make sure each one is in template_fields so render_template_fields + # would substitute them. + assert "source_path" in op.template_fields + assert "file_type" in op.template_fields + assert "file_extensions" in op.template_fields + assert "parser" in op.template_fields + assert "metadata_fields" in op.template_fields + # source_bytes intentionally not templated -- Jinja stringifies bytes + # to their repr, which would break binary parsing. + assert "source_bytes" not in op.template_fields + + def test_both_sources_raises(self): + with pytest.raises(ValueError, match="not both"): + DocumentLoaderOperator(task_id="test", source_path="/tmp/file.txt", source_bytes=b"hello") + + def test_neither_source_raises(self): + with pytest.raises(ValueError, match="Provide exactly one"): + DocumentLoaderOperator(task_id="test") + + def test_source_bytes_without_file_type_raises(self): + with pytest.raises(ValueError, match="file_type"): + DocumentLoaderOperator(task_id="test", source_bytes=b"hello") + + def test_empty_bytes_without_file_type_raises(self): + with pytest.raises(ValueError, match="file_type"): + DocumentLoaderOperator(task_id="test", source_bytes=b"") + + +class TestTextParser: + def test_txt_file(self, tmp_path): + f = tmp_path / "doc.txt" + f.write_text("Hello world", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert result[0]["text"] == "Hello world" + assert result[0]["metadata"]["file_name"] == "doc.txt" + + def test_md_file(self, tmp_path): + f = tmp_path / "readme.md" + f.write_text("# Title\n\nSome content", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert 
len(result) == 1 + assert "# Title" in result[0]["text"] + + +class TestCsvParser: + def test_csv_one_doc_per_row(self, tmp_path): + f = tmp_path / "data.csv" + f.write_text("name,age\nAlice,30\nBob,25\n", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 2 + assert "Alice" in result[0]["text"] + assert "Bob" in result[1]["text"] + assert result[0]["metadata"]["row_index"] == 0 + assert result[1]["metadata"]["row_index"] == 1 + + def test_csv_empty_cells_skipped(self, tmp_path): + f = tmp_path / "data.csv" + # Bob has no age -- "age: " should not appear in his row text. + f.write_text("name,age\nAlice,30\nBob,\n", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert "age: " not in result[1]["text"] + assert "Bob" in result[1]["text"] + + def test_csv_from_bytes(self): + raw = b"col1,col2\nval1,val2\n" + op = DocumentLoaderOperator(task_id="test", source_bytes=raw, file_type=".csv") + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert "val1" in result[0]["text"] + + +class TestJsonParser: + def test_json_array_flattens_dicts(self, tmp_path): + f = tmp_path / "items.json" + data = [{"title": "First", "tag": "alpha"}, {"title": "Second", "tag": "beta"}] + f.write_text(json.dumps(data), encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 2 + # Embedding should see "title: First, tag: alpha" rather than the raw + # JSON syntax tokens. 
+ assert "title: First" in result[0]["text"] + assert "tag: alpha" in result[0]["text"] + assert result[0]["text"].startswith("title:") # no leading "{" + assert result[0]["metadata"]["item_index"] == 0 + + def test_json_single_object_flattens(self, tmp_path): + f = tmp_path / "config.json" + f.write_text('{"key": "value", "other": "thing"}', encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert "key: value" in result[0]["text"] + assert "other: thing" in result[0]["text"] + + def test_json_string_primitives(self, tmp_path): + f = tmp_path / "strings.json" + f.write_text('["alpha", "beta", "gamma"]', encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 3 + assert result[0]["text"] == "alpha" + assert result[1]["text"] == "beta" + assert result[2]["text"] == "gamma" + + def test_json_text_field_pulls_body_keeps_rest_as_metadata(self, tmp_path): + f = tmp_path / "articles.json" + data = [ + {"title": "Hello", "body": "First article body.", "author": "Alice"}, + {"title": "World", "body": "Second article body.", "author": "Bob"}, + ] + f.write_text(json.dumps(data), encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f), json_text_field="body") + result = op.execute(context=MagicMock()) + + assert len(result) == 2 + assert result[0]["text"] == "First article body." 
+ assert result[0]["metadata"]["title"] == "Hello" + assert result[0]["metadata"]["author"] == "Alice" + assert "body" not in result[0]["metadata"] + + def test_json_from_bytes(self): + raw = b'[{"a": 1}, {"b": 2}]' + op = DocumentLoaderOperator(task_id="test", source_bytes=raw, file_type=".json") + result = op.execute(context=MagicMock()) + + assert len(result) == 2 + + +def _make_mock_pypdf_module(mock_reader): + """Create a fake pypdf module with a PdfReader that returns mock_reader.""" + mock_module = MagicMock() + mock_module.PdfReader = MagicMock(return_value=mock_reader) + return mock_module + + +def _make_mock_docx_module(mock_doc): + """Create a fake docx module with a Document that returns mock_doc.""" + mock_module = MagicMock() + mock_module.Document = MagicMock(return_value=mock_doc) + return mock_module + + +class TestPdfParser: + def test_pdf_parsing(self, tmp_path): + mock_page_1 = MagicMock() + mock_page_1.extract_text.return_value = "Page one content" + mock_page_2 = MagicMock() + mock_page_2.extract_text.return_value = "Page two content" + + mock_reader = MagicMock() + mock_reader.pages = [mock_page_1, mock_page_2] + + f = tmp_path / "report.pdf" + f.write_bytes(b"fake pdf bytes") + + mock_pypdf = _make_mock_pypdf_module(mock_reader) + with patch.dict("sys.modules", {"pypdf": mock_pypdf}): + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 2 + assert result[0]["text"] == "Page one content" + assert result[0]["metadata"]["page_number"] == 1 + assert result[1]["metadata"]["page_number"] == 2 + + def test_pdf_from_bytes_uses_stream_no_tempfile(self, tmp_path): + """Bytes-mode parsing should go through BytesIO, never a temp file.""" + mock_page = MagicMock() + mock_page.extract_text.return_value = "Streamed content" + mock_reader = MagicMock() + mock_reader.pages = [mock_page] + + mock_pypdf = _make_mock_pypdf_module(mock_reader) + with patch.dict("sys.modules", 
{"pypdf": mock_pypdf}): + op = DocumentLoaderOperator(task_id="test", source_bytes=b"%PDF-1.4 ...", file_type=".pdf") + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert result[0]["text"] == "Streamed content" + # PdfReader should have been called once with a stream (BytesIO), + # not a file path string. + mock_pypdf.PdfReader.assert_called_once() + (call_arg,) = mock_pypdf.PdfReader.call_args.args + import io as _io + + assert isinstance(call_arg, _io.BytesIO) + + def test_pdf_skips_empty_pages(self, tmp_path): + mock_page = MagicMock() + mock_page.extract_text.return_value = " " + mock_reader = MagicMock() + mock_reader.pages = [mock_page] + + f = tmp_path / "empty.pdf" + f.write_bytes(b"fake pdf") + + mock_pypdf = _make_mock_pypdf_module(mock_reader) + with patch.dict("sys.modules", {"pypdf": mock_pypdf}): + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 0 + + def test_pdf_missing_raises_optional_feature_exception(self, tmp_path): + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + f = tmp_path / "doc.pdf" + f.write_bytes(b"fake pdf") + + with patch.dict("sys.modules", {"pypdf": None}): + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + with pytest.raises(AirflowOptionalProviderFeatureException): + op.execute(context=MagicMock()) + + +class TestDocxParser: + def test_docx_parsing(self, tmp_path): + mock_para_1 = MagicMock() + mock_para_1.text = "First paragraph" + mock_para_2 = MagicMock() + mock_para_2.text = "Second paragraph" + mock_para_empty = MagicMock() + mock_para_empty.text = " " + + mock_doc_obj = MagicMock() + mock_doc_obj.paragraphs = [mock_para_1, mock_para_empty, mock_para_2] + + f = tmp_path / "doc.docx" + f.write_bytes(b"fake docx") + + mock_docx = _make_mock_docx_module(mock_doc_obj) + with patch.dict("sys.modules", {"docx": mock_docx}): + op = 
DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert "First paragraph" in result[0]["text"] + assert "Second paragraph" in result[0]["text"] + + def test_docx_from_bytes_uses_stream_no_tempfile(self): + mock_para = MagicMock() + mock_para.text = "Stream paragraph" + mock_doc_obj = MagicMock() + mock_doc_obj.paragraphs = [mock_para] + + mock_docx = _make_mock_docx_module(mock_doc_obj) + with patch.dict("sys.modules", {"docx": mock_docx}): + op = DocumentLoaderOperator(task_id="test", source_bytes=b"fake docx", file_type=".docx") + result = op.execute(context=MagicMock()) + + assert "Stream paragraph" in result[0]["text"] + mock_docx.Document.assert_called_once() + (call_arg,) = mock_docx.Document.call_args.args + import io as _io + + assert isinstance(call_arg, _io.BytesIO) + + def test_docx_missing_raises_optional_feature_exception(self, tmp_path): + from airflow.providers.common.compat.sdk import AirflowOptionalProviderFeatureException + + f = tmp_path / "doc.docx" + f.write_bytes(b"fake docx") + + with patch.dict("sys.modules", {"docx": None}): + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + with pytest.raises(AirflowOptionalProviderFeatureException): + op.execute(context=MagicMock()) + + +class TestFileDiscovery: + def test_glob_multiple_files(self, tmp_path): + (tmp_path / "a.txt").write_text("file a", encoding="utf-8") + (tmp_path / "b.txt").write_text("file b", encoding="utf-8") + (tmp_path / "c.md").write_text("file c", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path / "*.txt")) + result = op.execute(context=MagicMock()) + + assert len(result) == 2 + texts = {doc["text"] for doc in result} + assert texts == {"file a", "file b"} + + def test_recursive_glob(self, tmp_path): + nested = tmp_path / "year" / "month" + nested.mkdir(parents=True) + (tmp_path / "top.txt").write_text("top", encoding="utf-8") + (nested / 
"deep.txt").write_text("deep", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path / "**" / "*.txt")) + result = op.execute(context=MagicMock()) + + texts = {doc["text"] for doc in result} + assert texts == {"top", "deep"} + + def test_directory_source(self, tmp_path): + (tmp_path / "x.txt").write_text("hello", encoding="utf-8") + (tmp_path / "y.md").write_text("world", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path)) + result = op.execute(context=MagicMock()) + + assert len(result) == 2 + + def test_directory_mode_skips_dotfiles(self, tmp_path): + (tmp_path / "keep.txt").write_text("keep", encoding="utf-8") + (tmp_path / ".DS_Store").write_bytes(b"\x00\x00") + (tmp_path / ".hidden.txt").write_text("nope", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path)) + result = op.execute(context=MagicMock()) + + # Only the non-dotfile is parsed; .DS_Store and .hidden.txt are ignored. 
+ assert len(result) == 1 + assert result[0]["text"] == "keep" + + def test_directory_mode_warns_and_skips_unknown_extensions(self, tmp_path, caplog): + (tmp_path / "keep.txt").write_text("keep", encoding="utf-8") + (tmp_path / "stray.xyz").write_text("ignored", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path)) + with caplog.at_level(logging.WARNING): + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert result[0]["text"] == "keep" + assert any(".xyz" in record.message for record in caplog.records) + + def test_file_extensions_filter(self, tmp_path): + (tmp_path / "keep.txt").write_text("keep me", encoding="utf-8") + (tmp_path / "skip.md").write_text("skip me", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path), file_extensions=[".txt"]) + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert result[0]["text"] == "keep me" + + def test_empty_directory_raises_file_not_found(self, tmp_path): + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path)) + with pytest.raises(FileNotFoundError, match="No files found"): + op.execute(context=MagicMock()) + + def test_unknown_extension_on_single_file_raises(self, tmp_path): + f = tmp_path / "data.xyz" + f.write_text("some data", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + with pytest.raises(ValueError, match="No parser registered"): + op.execute(context=MagicMock()) + + def test_nonexistent_glob_raises_file_not_found(self, tmp_path): + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path / "*.nope")) + with pytest.raises(FileNotFoundError, match="No files found"): + op.execute(context=MagicMock()) + + def test_file_extensions_case_insensitive(self, tmp_path): + (tmp_path / "keep.txt").write_text("keep me", encoding="utf-8") + (tmp_path / "skip.md").write_text("skip me", encoding="utf-8") + + op = 
DocumentLoaderOperator(task_id="test", source_path=str(tmp_path), file_extensions=[".TXT"]) + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert result[0]["text"] == "keep me" + + +class TestCloudUriDispatch: + """``source_path`` containing a URI scheme routes through ObjectStoragePath.""" + + @patch("airflow.sdk.ObjectStoragePath") + def test_single_object_uri_returns_one_document(self, mock_osp_cls): + # `str(mock_obj)` returns whatever MagicMock renders; we only assert + # the file_name field, not file_path, so leaving __str__ default is + # fine and avoids mypy's method-assign complaint. + mock_obj = MagicMock() + mock_obj.is_file.return_value = True + mock_obj.suffix = ".txt" + mock_obj.name = "report.txt" + mock_obj.read_bytes.return_value = b"cloud content" + mock_osp_cls.return_value = mock_obj + + op = DocumentLoaderOperator( + task_id="test", + source_path="s3://bucket/dir/report.txt", + source_conn_id="aws_default", + ) + result = op.execute(context=MagicMock()) + + mock_osp_cls.assert_called_once_with("s3://bucket/dir/report.txt", conn_id="aws_default") + assert len(result) == 1 + assert result[0]["text"] == "cloud content" + assert result[0]["metadata"]["file_name"] == "report.txt" + + @patch("airflow.sdk.ObjectStoragePath") + def test_directory_uri_iterates_children(self, mock_osp_cls): + # Root is a directory; iterdir yields two text files. + def _mock_child(name: str, content: bytes): + child = MagicMock() + child.is_file.return_value = True + child.name = name + child.suffix = "." 
+ name.rsplit(".", 1)[-1] + child.read_bytes.return_value = content + return child + + a = _mock_child("a.txt", b"alpha") + b = _mock_child("b.txt", b"beta") + + root = MagicMock() + root.is_file.return_value = False + root.is_dir.return_value = True + root.iterdir.return_value = [a, b] + mock_osp_cls.return_value = root + + op = DocumentLoaderOperator(task_id="test", source_path="s3://bucket/dir/") + result = op.execute(context=MagicMock()) + + assert {doc["text"] for doc in result} == {"alpha", "beta"} + + @patch("airflow.sdk.ObjectStoragePath") + def test_neither_file_nor_dir_uri_raises(self, mock_osp_cls): + bad = MagicMock() + bad.is_file.return_value = False + bad.is_dir.return_value = False + mock_osp_cls.return_value = bad + + op = DocumentLoaderOperator(task_id="test", source_path="s3://bucket/missing") + with pytest.raises(FileNotFoundError, match="neither a file nor a directory"): + op.execute(context=MagicMock()) + + +class TestEncoding: + def test_strict_utf8_default_raises_with_path_context(self, tmp_path): + f = tmp_path / "latin1.csv" + # \xff is invalid in UTF-8; surface the file path in the error. 
+ f.write_bytes(b"\xff\xfe header\nrow\n") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + with pytest.raises(ValueError, match=str(f)): + op.execute(context=MagicMock()) + + def test_encoding_errors_replace_tolerates_garbage(self, tmp_path): + f = tmp_path / "mixed.txt" + f.write_bytes(b"hello \xff world") + + op = DocumentLoaderOperator( + task_id="test", + source_path=str(f), + encoding_errors="replace", + ) + result = op.execute(context=MagicMock()) + + assert "hello" in result[0]["text"] + assert "world" in result[0]["text"] + + def test_alternative_encoding_succeeds(self, tmp_path): + f = tmp_path / "latin1.txt" + f.write_bytes("café".encode("latin-1")) + + op = DocumentLoaderOperator(task_id="test", source_path=str(f), encoding="latin-1") + result = op.execute(context=MagicMock()) + + assert "café" in result[0]["text"] + + +class TestOutputShape: + def test_every_item_has_text_and_metadata(self, tmp_path): + (tmp_path / "a.txt").write_text("doc a", encoding="utf-8") + (tmp_path / "b.txt").write_text("doc b", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(tmp_path / "*.txt")) + result = op.execute(context=MagicMock()) + + for doc in result: + assert "text" in doc + assert "metadata" in doc + assert isinstance(doc["text"], str) + assert isinstance(doc["metadata"], dict) + + def test_metadata_fields_appended(self, tmp_path): + f = tmp_path / "doc.txt" + f.write_text("content", encoding="utf-8") + + op = DocumentLoaderOperator( + task_id="test", + source_path=str(f), + metadata_fields={"source": "test_suite", "version": 2}, + ) + result = op.execute(context=MagicMock()) + + assert result[0]["metadata"]["source"] == "test_suite" + assert result[0]["metadata"]["version"] == 2 + + def test_metadata_fields_do_not_override_auto_extracted(self, tmp_path): + """Auto-extracted file_name wins over a same-key entry in metadata_fields.""" + f = tmp_path / "report.txt" + f.write_text("content", encoding="utf-8") + + 
op = DocumentLoaderOperator( + task_id="test", + source_path=str(f), + metadata_fields={"file_name": "spoofed", "extra": "kept"}, + ) + result = op.execute(context=MagicMock()) + + assert result[0]["metadata"]["file_name"] == "report.txt" + assert result[0]["metadata"]["extra"] == "kept" + + def test_file_metadata_included(self, tmp_path): + f = tmp_path / "report.txt" + f.write_text("content", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f)) + result = op.execute(context=MagicMock()) + + assert result[0]["metadata"]["file_name"] == "report.txt" + assert "file_path" in result[0]["metadata"] + + def test_source_bytes_no_file_metadata(self): + op = DocumentLoaderOperator(task_id="test", source_bytes=b"hello", file_type=".txt") + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert result[0]["text"] == "hello" + assert "file_name" not in result[0]["metadata"] + + def test_explicit_parser_override(self, tmp_path): + f = tmp_path / "data.log" + f.write_text("log line", encoding="utf-8") + + op = DocumentLoaderOperator(task_id="test", source_path=str(f), parser="text") + result = op.execute(context=MagicMock()) + + assert len(result) == 1 + assert result[0]["text"] == "log line" diff --git a/uv.lock b/uv.lock index 73e405dd33856..fa0f1ed32d7d4 100644 --- a/uv.lock +++ b/uv.lock @@ -4209,12 +4209,14 @@ bedrock = [ common-sql = [ { name = "apache-airflow-providers-common-sql" }, ] +docx = [ + { name = "python-docx" }, +] google = [ { name = "pydantic-ai-slim", extra = ["google"] }, ] langchain = [ { name = "langchain" }, - { name = "langchain-openai" }, ] mcp = [ { name = "pydantic-ai-slim", extra = ["mcp"] }, @@ -4225,6 +4227,9 @@ openai = [ parquet = [ { name = "pyarrow" }, ] +pdf = [ + { name = "pypdf" }, +] sql = [ { name = "apache-airflow-providers-common-sql" }, { name = "sqlglot" }, @@ -4238,6 +4243,7 @@ dev = [ { name = "apache-airflow-providers-common-sql", extra = ["datafusion"] }, { name = 
"apache-airflow-providers-standard" }, { name = "apache-airflow-task-sdk" }, + { name = "langchain" }, { name = "pydantic-ai-slim", extra = ["mcp"] }, { name = "sqlglot" }, ] @@ -4255,7 +4261,6 @@ requires-dist = [ { name = "fastavro", marker = "python_full_version >= '3.14' and extra == 'avro'", specifier = ">=1.12.1" }, { name = "fastavro", marker = "python_full_version < '3.14' and extra == 'avro'", specifier = ">=1.10.0" }, { name = "langchain", marker = "extra == 'langchain'", specifier = ">=1.0.0" }, - { name = "langchain-openai", marker = "extra == 'langchain'", specifier = ">=0.3.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14' and extra == 'parquet'", specifier = ">=22.0.0" }, { name = "pyarrow", marker = "python_full_version < '3.14' and extra == 'parquet'", specifier = ">=18.0.0" }, { name = "pydantic-ai-slim", specifier = ">=1.34.0" }, @@ -4264,9 +4269,11 @@ requires-dist = [ { name = "pydantic-ai-slim", extras = ["google"], marker = "extra == 'google'" }, { name = "pydantic-ai-slim", extras = ["mcp"], marker = "extra == 'mcp'" }, { name = "pydantic-ai-slim", extras = ["openai"], marker = "extra == 'openai'" }, + { name = "pypdf", marker = "extra == 'pdf'", specifier = ">=4.0.0" }, + { name = "python-docx", marker = "extra == 'docx'", specifier = ">=1.0.0" }, { name = "sqlglot", marker = "extra == 'sql'", specifier = ">=30.0.0" }, ] -provides-extras = ["anthropic", "bedrock", "google", "openai", "mcp", "avro", "parquet", "sql", "common-sql", "langchain"] +provides-extras = ["anthropic", "bedrock", "google", "openai", "mcp", "avro", "parquet", "sql", "common-sql", "langchain", "pdf", "docx"] [package.metadata.requires-dev] dev = [ @@ -4277,6 +4284,7 @@ dev = [ { name = "apache-airflow-providers-common-sql", extras = ["datafusion"], editable = "providers/common/sql" }, { name = "apache-airflow-providers-standard", editable = "providers/standard" }, { name = "apache-airflow-task-sdk", editable = "task-sdk" }, + { name = "langchain", 
specifier = ">=1.0.0" }, { name = "pydantic-ai-slim", extras = ["mcp"] }, { name = "sqlglot", specifier = ">=30.0.0" }, ] @@ -14587,20 +14595,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/1a/86c38c27b81913a1c6c12448cab55defb5a1097c7dc9a4cea83f55477a2d/langchain_core-1.4.0-py3-none-any.whl", hash = "sha256:23cbbdb46e38ddd1dd5247e6167e96013eae74bea4c5949c550809970a9e565c", size = 548120, upload-time = "2026-05-11T18:42:33.992Z" }, ] -[[package]] -name = "langchain-openai" -version = "1.2.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "langchain-core" }, - { name = "openai" }, - { name = "tiktoken" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9a/0e/d8e16c28aa67106d285e63b8ffc04c5af68341e345ce24a0751dbf2e167e/langchain_openai-1.2.1.tar.gz", hash = "sha256:ee4480b787706361b7125fad46930589a624df87aa158c6986ef1fad10d10675", size = 1146092, upload-time = "2026-04-24T19:46:43.328Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/55/2865b18ee3a3dd11160b8c4b2cf37e75bf2a4a8d1d38868ffffc7b7cc180/langchain_openai-1.2.1-py3-none-any.whl", hash = "sha256:a80732185030d4f453dda6c25feef46f645f665423fdffe38ae3edf1ac3c6c4d", size = 98626, upload-time = "2026-04-24T19:46:41.971Z" }, -] - [[package]] name = "langchain-protocol" version = "0.0.15" @@ -18862,6 +18856,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, ] +[[package]] +name = "pypdf" +version = "6.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/bf/58/6dd97d78a4b17a7a6b9d1c6ad23895abc41f0fdc49c553cc05bdfdcc36d0/pypdf-6.11.0.tar.gz", hash = "sha256:062b51c81b0910e6d2755e99e1c5547a0a23b7d0a32322af66240d8edcfabe87", size = 6453975, upload-time = "2026-05-09T13:26:48.955Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/b1/68feb7eb3b99f0c020b414234825f4a5d70e0126c18d933770e8c93a35fc/pypdf-6.11.0-py3-none-any.whl", hash = "sha256:769394d5756d5b304c9b6bef88b54b1816b328e7e6fc9254e625529a15ed4ab8", size = 338819, upload-time = "2026-05-09T13:26:46.904Z" }, +] + [[package]] name = "pyproject-hooks" version = "1.2.0" @@ -19215,6 +19221,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/6f/a05a317a66fee0aad270011461f1a63a453ed12471249f172f7d2e2bc7b4/python_discovery-1.3.1-py3-none-any.whl", hash = "sha256:ed188687ebb3b82c01a17cd5ac62fc94d9f6487a7f1a0f9dfe89753fec91039c", size = 33185, upload-time = "2026-05-12T20:53:34.969Z" }, ] +[[package]] +name = "python-docx" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "lxml" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a9/f7/eddfe33871520adab45aaa1a71f0402a2252050c14c7e3009446c8f4701c/python_docx-1.2.0.tar.gz", hash = "sha256:7bc9d7b7d8a69c9c02ca09216118c86552704edc23bac179283f2e38f86220ce", size = 5723256, upload-time = "2025-06-16T20:46:27.921Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/00/1e03a4989fa5795da308cd774f05b704ace555a70f9bf9d3be057b680bcf/python_docx-1.2.0-py3-none-any.whl", hash = "sha256:3fd478f3250fbbbfd3b94fe1e985955737c145627498896a8a6bf81f4baf66c7", size = 252987, upload-time = "2025-06-16T20:46:22.506Z" }, +] + [[package]] name = "python-dotenv" version = "1.2.2"