51 changes: 51 additions & 0 deletions providers/common/ai/docs/hooks/index.rst
@@ -0,0 +1,51 @@
 .. Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements. See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership. The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License. You may obtain a copy of the License at

 ..   http://www.apache.org/licenses/LICENSE-2.0

 .. Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied. See the License for the
    specific language governing permissions and limitations
    under the License.

Common AI Hooks
===============

The common-ai provider ships hooks that bridge an Airflow connection to the model
objects of a specific LLM framework. Each hook is a thin adapter: it reads credentials
and configuration from the connection, then returns native framework objects (a
``pydantic_ai`` ``Agent`` / ``Model``, a LangChain ``BaseChatModel`` or ``Embeddings``,
an MCP client, ...). Operators and ``@task`` decorators in this provider use these
hooks internally.

Choosing a hook
---------------

.. list-table::
   :header-rows: 1
   :widths: 25 75

   * - Hook
     - When to use
   * - :class:`~airflow.providers.common.ai.hooks.pydantic_ai.PydanticAIHook`
     - Default for ``common.ai`` operators (``LLMOperator``, ``AgentOperator``,
       ``LLMBranchOperator``, ...). Returns a pydantic-ai ``Agent`` / ``Model``.
   * - :class:`~airflow.providers.common.ai.hooks.langchain.LangChainHook`
     - Direct LangChain access for tasks that compose ``Runnable``\s, use the
       LangChain agent surface, or need LangChain-native chat / embedding model
       objects. Independent of the pydantic-ai-backed operators.

Hook guides
-----------

.. toctree::
    :maxdepth: 1
    :glob:

    *
174 changes: 174 additions & 0 deletions providers/common/ai/docs/hooks/langchain.rst
@@ -0,0 +1,174 @@
 .. Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements. See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership. The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License. You may obtain a copy of the License at

 ..   http://www.apache.org/licenses/LICENSE-2.0

 .. Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied. See the License for the
    specific language governing permissions and limitations
    under the License.

.. _howto/hook:langchain:

``LangChainHook``
=================

Use :class:`~airflow.providers.common.ai.hooks.langchain.LangChainHook` to
bridge an Airflow connection to `LangChain <https://python.langchain.com/>`__
chat and embedding models. The hook reads credentials (API key, optional base
URL) from the connection and returns configured LangChain model objects through
LangChain's two provider-agnostic entry-point functions:

- ``langchain.chat_models.init_chat_model`` for chat models, which dispatches to
  the right vendor based on the prefix of the ``provider:name`` identifier.
- ``langchain.embeddings.init_embeddings`` for embedding models, which uses the
  same prefix-based dispatch.

The hook registers its own ``langchain`` connection type, so the connection UI
states explicitly which framework a connection configures.
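
As a rough illustration of the ``provider:name`` convention, the sketch below
splits an identifier into its two parts. ``split_model_identifier`` is a
hypothetical helper written for this example; it is not part of the hook or of
LangChain, whose own parsing may differ.

```python
def split_model_identifier(identifier: str) -> tuple:
    """Split 'provider:name' into (provider, name).

    Hypothetical helper for illustration only. Splits on the first colon,
    so model names that themselves contain a colon stay intact.
    """
    provider, separator, name = identifier.partition(":")
    if not separator or not provider or not name:
        raise ValueError(
            f"Expected an identifier in 'provider:name' form, got {identifier!r}"
        )
    return provider, name


provider, name = split_model_identifier("openai:gpt-4o")
print(provider, name)
```

Splitting on the first colon only is a deliberate choice in this sketch, so a
tag-style name such as ``ollama:llama3:8b`` keeps its suffix in the name part.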

Chat model usage
----------------

Pass ``llm_model`` to the constructor (or set ``extra["model"]`` on the
connection) and call ``get_chat_model()``:

.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_langchain_hook.py
    :language: python
    :start-after: [START howto_hook_langchain_chat]
    :end-before: [END howto_hook_langchain_chat]

The returned model is a LangChain ``BaseChatModel``, so it composes with the
rest of LangChain's runnable surface
(``ChatPromptTemplate`` / ``StrOutputParser`` / ``RunnableSequence`` / ...).

Supported chat providers
~~~~~~~~~~~~~~~~~~~~~~~~

Any model identifier accepted by
`langchain.chat_models.init_chat_model <https://python.langchain.com/api_reference/langchain/chat_models/langchain.chat_models.base.init_chat_model.html>`__
works, provided the matching LangChain integration package is installed.
Common identifiers:

- ``openai:gpt-4o``, ``openai:gpt-4o-mini`` -- requires ``langchain-openai``
- ``anthropic:claude-3-7-sonnet`` -- requires ``langchain-anthropic``
- ``groq:llama-3.3-70b-versatile`` -- requires ``langchain-groq``
- ``mistralai:mistral-large-latest`` -- requires ``langchain-mistralai``
- ``ollama:llama3`` -- requires ``langchain-ollama`` (point ``host`` at the Ollama URL)
- ``deepseek:deepseek-chat`` -- requires ``langchain-deepseek``

Cloud providers with non-standard authentication (AWS Bedrock, Google Vertex AI,
Azure OpenAI) are not covered by the ``api_key`` + ``base_url`` surface of this
hook. Support for them is deferred to per-vendor hooks, mirroring the
pydantic-ai cloud-auth subclass pattern.

Embedding model usage
---------------------

Pass ``embed_model`` to the constructor (or set ``extra["embed_model"]`` on
the connection) and call ``get_embedding_model()``:

.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_langchain_hook.py
    :language: python
    :start-after: [START howto_hook_langchain_embedding]
    :end-before: [END howto_hook_langchain_embedding]

The same hook instance can serve both chat and embedding models when both
identifiers are set:

.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_langchain_hook.py
    :language: python
    :start-after: [START howto_hook_langchain_chat_and_embedding]
    :end-before: [END howto_hook_langchain_chat_and_embedding]

Supported embedding providers
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The hook passes ``api_key`` and (optional) ``base_url`` from the connection to
`langchain.embeddings.init_embeddings <https://reference.langchain.com/python/langchain/embeddings/base/init_embeddings>`__.
Providers whose embedding classes accept these keyword arguments work directly:

- ``openai:text-embedding-3-small``, ``openai:text-embedding-3-large`` -- requires ``langchain-openai``
- ``openai:<model>`` against an OpenAI-compatible endpoint (point ``host`` at
  Ollama / vLLM / LM Studio) -- requires ``langchain-openai``

``init_embeddings`` advertises more providers (Cohere, Mistral AI, HuggingFace,
Bedrock, Vertex AI, Azure OpenAI, ...), but their embedding classes expect
provider-specific credential kwargs (``cohere_api_key``, AWS auth chain, GCP
service-account, ...) rather than the generic ``api_key`` / ``base_url`` this
hook forwards. Those are deferred to per-vendor subclasses mirroring the
pydantic-ai pattern (``PydanticAIBedrockHook`` / ``PydanticAIVertexHook`` /
``PydanticAIAzureHook``).

Different connections for chat and embeddings
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If chat and embeddings use different API keys (e.g. a premium chat key and a
free-tier embeddings key), pass an explicit ``embed_conn_id``. When it is unset,
the hook falls back to ``llm_conn_id``, so the common one-provider case stays
simple:

.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_langchain_hook.py
    :language: python
    :start-after: [START howto_hook_langchain_different_conns]
    :end-before: [END howto_hook_langchain_different_conns]
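
The fallback rule for connection IDs can be pictured with a few lines of plain
Python. This is a minimal sketch of the behaviour described in this section,
not the hook's actual code; ``resolve_conn_ids`` is a hypothetical name.

```python
from typing import Optional


def resolve_conn_ids(llm_conn_id: str, embed_conn_id: Optional[str] = None) -> dict:
    """Return the connection ID used for each model kind.

    Hypothetical helper: embeddings reuse the chat connection unless a
    separate embed_conn_id is given.
    """
    return {
        "chat": llm_conn_id,
        "embeddings": embed_conn_id if embed_conn_id is not None else llm_conn_id,
    }


# One provider for everything: both kinds share the same connection.
print(resolve_conn_ids("langchain_default"))

# Separate keys: embeddings use their own connection (example ID is made up).
print(resolve_conn_ids("langchain_default", "langchain_embeddings"))
```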

Connection configuration
------------------------

The hook reads credentials from the Airflow connection of type ``langchain``:

- **password** -- API key (passed as ``api_key`` to ``init_chat_model`` and
  ``init_embeddings``).
- **host** -- Optional base URL (passed as ``base_url``; useful for custom
  OpenAI-compatible endpoints, Ollama, vLLM).
- **extra** JSON -- ``{"model": "openai:gpt-4o", "embed_model": "openai:text-embedding-3-small"}``
  to set default chat and embedding model identifiers on the connection.
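
To make the field mapping concrete, the sketch below builds such a connection
as a JSON document (the form Airflow accepts in an ``AIRFLOW_CONN_<CONN_ID>``
environment variable) and derives the keyword arguments from it. The API key is
a placeholder, and ``credential_kwargs`` is a hypothetical helper that only
mirrors the mapping listed above; the hook's real implementation may differ.

```python
import json

# Connection fields as described above. The API key is a placeholder.
connection = {
    "conn_type": "langchain",
    "password": "sk-placeholder",
    "host": "https://api.openai.com/v1",
    "extra": {
        "model": "openai:gpt-4o",
        "embed_model": "openai:text-embedding-3-small",
    },
}

# JSON-serialized form, suitable as the value of an AIRFLOW_CONN_* variable.
env_name = "AIRFLOW_CONN_LANGCHAIN_DEFAULT"
env_value = json.dumps(connection)


def credential_kwargs(conn: dict) -> dict:
    """Hypothetical helper: password -> api_key, host -> base_url (if set)."""
    kwargs = {"api_key": conn["password"]}
    if conn.get("host"):
        kwargs["base_url"] = conn["host"]
    return kwargs


print(env_name)
print(credential_kwargs(connection))
```

Leaving ``base_url`` out when ``host`` is empty lets the vendor integration use
its default endpoint, which is the assumption made in this sketch.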

Parameters
----------

.. list-table::
   :header-rows: 1
   :widths: 25 25 50

   * - Parameter
     - Default
     - Description
   * - ``llm_conn_id``
     - ``langchain_default``
     - Airflow connection ID for the LLM provider.
   * - ``embed_conn_id``
     - ``None`` (falls back to ``llm_conn_id``)
     - Optional separate Airflow connection ID for the embedding provider.
       Useful when chat and embeddings live on different API keys; in the
       common one-provider case, leave unset and the hook reuses ``llm_conn_id``.
   * - ``llm_model``
     - ``None`` (falls back to ``extra["model"]`` on the connection)
     - Chat model identifier in ``provider:name`` form, e.g. ``openai:gpt-4o``.
       Only required when calling ``get_chat_model()``.
   * - ``embed_model``
     - ``None`` (falls back to ``extra["embed_model"]`` on the connection)
     - Embedding model identifier in ``provider:name`` form, e.g.
       ``openai:text-embedding-3-small``. Only required when calling
       ``get_embedding_model()``.
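
The model fallback in the table can be sketched as follows. ``resolve_model``
is a hypothetical helper, and raising ``ValueError`` when neither source
provides an identifier is an assumption of this sketch; the hook may report
the problem differently.

```python
from typing import Optional


def resolve_model(explicit: Optional[str], extra: dict, extra_key: str) -> str:
    """Pick the constructor argument first, then the connection extra.

    Hypothetical helper; the error raised when both are missing is an
    assumption made for this example.
    """
    if explicit:
        return explicit
    from_extra = extra.get(extra_key)
    if from_extra:
        return from_extra
    raise ValueError(
        f"No model configured: pass it explicitly or set extra[{extra_key!r}]"
    )


extra = {"model": "openai:gpt-4o", "embed_model": "openai:text-embedding-3-small"}

# Constructor argument wins over the connection default.
print(resolve_model("anthropic:claude-3-7-sonnet", extra, "model"))

# No constructor argument: fall back to the connection extra.
print(resolve_model(None, extra, "embed_model"))
```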

Dependencies
------------

Install the ``langchain`` extra to use this hook::

    pip install "apache-airflow-providers-common-ai[langchain]"

That extra installs only ``langchain`` itself, since the framework is
vendor-agnostic. In addition, install the LangChain integration package for
each provider you intend to use:

- ``langchain-openai`` -- OpenAI and OpenAI-compatible endpoints (Ollama, vLLM)
- ``langchain-anthropic`` -- Anthropic
- ``langchain-groq``, ``langchain-mistralai``, ``langchain-deepseek``, ``langchain-ollama``, ...
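
A quick way to see which integration package a given identifier needs is to
look up its provider prefix. The mapping below only restates the packages named
on this page, and both helpers are hypothetical; the import-name rule (hyphens
become underscores) is an assumption that holds for the packages listed here.

```python
from importlib.util import find_spec
from typing import Optional

# Provider prefix -> integration package, as listed on this page.
INTEGRATION_PACKAGES = {
    "openai": "langchain-openai",
    "anthropic": "langchain-anthropic",
    "groq": "langchain-groq",
    "mistralai": "langchain-mistralai",
    "ollama": "langchain-ollama",
    "deepseek": "langchain-deepseek",
}


def required_package(identifier: str) -> Optional[str]:
    """Return the package for a 'provider:name' identifier, or None if unknown."""
    provider = identifier.partition(":")[0]
    return INTEGRATION_PACKAGES.get(provider)


def is_installed(package: str) -> bool:
    """Check importability, assuming 'langchain-x' imports as 'langchain_x'."""
    return find_spec(package.replace("-", "_")) is not None


package = required_package("anthropic:claude-3-7-sonnet")
print(package, "installed" if is_installed(package) else "missing")
```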
2 changes: 1 addition & 1 deletion providers/common/ai/docs/index.rst
@@ -36,7 +36,7 @@

     Connection types <connections/pydantic_ai>
     MCP connection <connections/mcp>
-    Hooks <hooks/pydantic_ai>
+    Hooks <hooks/index>
     Toolsets <toolsets>
     Operators <operators/index>
     HITL Review <hitl_review>
39 changes: 39 additions & 0 deletions providers/common/ai/provider.yaml
@@ -48,6 +48,9 @@ integrations:
   - integration-name: MCP Server
     external-doc-url: https://modelcontextprotocol.io/
     tags: [ai]
+  - integration-name: LangChain
+    external-doc-url: https://python.langchain.com/
+    tags: [ai]

 hooks:
   - integration-name: Pydantic AI
@@ -56,6 +59,9 @@ hooks:
   - integration-name: MCP Server
     python-modules:
       - airflow.providers.common.ai.hooks.mcp
+  - integration-name: LangChain
+    python-modules:
+      - airflow.providers.common.ai.hooks.langchain

 plugins:
   - name: hitl_review
@@ -313,6 +319,39 @@ connection-types:
           type:
             - string
             - 'null'
+  - hook-class-name: airflow.providers.common.ai.hooks.langchain.LangChainHook
+    hook-name: "LangChain"
+    connection-type: langchain
+    ui-field-behaviour:
+      hidden-fields:
+        - schema
+        - port
+        - login
+      relabeling:
+        password: API Key
+      placeholders:
+        host: "https://api.openai.com/v1 (optional, for custom endpoints / Ollama)"
+    conn-fields:
+      model:
+        label: Chat Model
+        description: >
+          Chat model in provider:name format dispatched via
+          langchain.chat_models.init_chat_model
+          (e.g. openai:gpt-4o, anthropic:claude-3-7-sonnet).
+        schema:
+          type:
+            - string
+            - 'null'
+      embed_model:
+        label: Embedding Model
+        description: >
+          Embedding model in provider:name format dispatched via
+          langchain.embeddings.init_embeddings
+          (e.g. openai:text-embedding-3-small, openai:text-embedding-3-large).
+        schema:
+          type:
+            - string
+            - 'null'

 operators:
   - integration-name: Common AI
6 changes: 5 additions & 1 deletion providers/common/ai/pyproject.toml
@@ -95,6 +95,9 @@ dependencies = [
 "common.sql" = [
     "apache-airflow-providers-common-sql"
 ]
+"langchain" = [
+    "langchain>=1.0.0",
+]

 [dependency-groups]
 dev = [
@@ -107,7 +110,8 @@ dev = [
     # Additional devel dependencies (do not remove this line and add extra development dependencies)
     "sqlglot>=30.0.0",
     "pydantic-ai-slim[mcp]",
-    "apache-airflow-providers-common-sql[datafusion]"
+    "apache-airflow-providers-common-sql[datafusion]",
+    "langchain>=1.0.0",
 ]

# To build docs: