Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions providers/amazon/docs/operators/bedrock.rst
Original file line number Diff line number Diff line change
Expand Up @@ -417,3 +417,32 @@ To update an Amazon Bedrock guardrail configuration, use
:dedent: 4
:start-after: [START howto_operator_bedrock_update_guardrail]
:end-before: [END howto_operator_bedrock_update_guardrail]


.. _howto/operator:BedrockRerankOperator:

Rerank Documents
----------------

To rerank a list of documents based on their relevance to a query using Amazon Bedrock,
you can use :class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRerankOperator`.

This operator uses the Bedrock Agent Runtime ``Rerank`` API to score and reorder documents,
which is useful for improving RAG pipeline quality by filtering and prioritizing retrieved
results before passing them to a generative model.

.. code-block:: python

from airflow.providers.amazon.aws.operators.bedrock import BedrockRerankOperator

rerank = BedrockRerankOperator(
task_id="rerank_results",
query="What is serverless computing?",
documents=[
{"textDocument": {"text": "AWS Lambda is a serverless compute service."}},
{"textDocument": {"text": "Amazon EC2 provides virtual servers in the cloud."}},
{"textDocument": {"text": "Serverless eliminates infrastructure management."}},
],
model_id="cohere.rerank-v3-5:0",
number_of_results=2,
)
Original file line number Diff line number Diff line change
Expand Up @@ -1334,3 +1334,91 @@ def execute(self, context: Context) -> str:
job_arn = response["jobArn"]
self.log.info("Created evaluation job %s: %s", self.job_name, job_arn)
return job_arn


class BedrockRerankOperator(AwsBaseOperator[BedrockAgentRuntimeHook]):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please create unit tests associated to this operator

"""
Rerank a list of documents based on their relevance to a query using Amazon Bedrock.

Uses the Bedrock Agent Runtime ``Rerank`` API to score and reorder documents
by relevance, which is useful for improving RAG pipeline quality by filtering
and prioritizing retrieved results before passing them to a generative model.

.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BedrockRerankOperator`

:param query: The search query to rerank documents against. (templated)
:param documents: List of documents to rerank. Each document should be a dict
with at minimum a 'textDocument' key containing {'text': '...'}. (templated)
:param model_id: The model ID for the reranking model.
Defaults to 'cohere.rerank-v3-5:0'. (templated)
:param number_of_results: Maximum number of results to return.
If not specified, all documents are returned reranked. (templated)
:param rerank_kwargs: Additional keyword arguments to pass to the Rerank API call. (templated)
"""

aws_hook_class = BedrockAgentRuntimeHook
template_fields: Sequence[str] = aws_template_fields(
"query",
"documents",
"model_id",
"number_of_results",
"rerank_kwargs",
)

def __init__(
self,
query: str,
documents: list[dict[str, Any]],
model_id: str = "cohere.rerank-v3-5:0",
number_of_results: int | None = None,
rerank_kwargs: dict[str, Any] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.query = query
self.documents = documents
self.model_id = model_id
self.number_of_results = number_of_results
self.rerank_kwargs = rerank_kwargs or {}

def execute(self, context: Context) -> list[dict[str, Any]]:
self.log.info(
"Reranking %d documents with model %s",
len(self.documents),
self.model_id,
)

sources = [{"inlineDocumentSource": doc} for doc in self.documents]

rerank_config: dict[str, Any] = {
"type": "INLINE",
"inlineDocumentSources": sources,
}

queries = [{"type": "TEXT", "textQuery": {"text": self.query}}]

reranking_configuration: dict[str, Any] = {
"bedrockRerankingConfiguration": {
"modelConfiguration": {"modelArn": f"arn:aws:bedrock:{self.hook.region_name}::foundation-model/{self.model_id}"},
}
}

if self.number_of_results is not None:
reranking_configuration["bedrockRerankingConfiguration"]["numberOfResults"] = self.number_of_results

kwargs: dict[str, Any] = prune_dict(
{
"sources": [rerank_config],
"queries": queries,
"rerankingConfiguration": reranking_configuration,
**self.rerank_kwargs,
}
)

response = self.hook.conn.rerank(**kwargs)
results = response.get("results", [])

self.log.info("Reranking complete. Returned %d results.", len(results))
return results
Loading