diff --git a/providers/amazon/docs/operators/bedrock.rst b/providers/amazon/docs/operators/bedrock.rst index 6cfd217b4000d..2707bb632dd21 100644 --- a/providers/amazon/docs/operators/bedrock.rst +++ b/providers/amazon/docs/operators/bedrock.rst @@ -417,3 +417,32 @@ To update an Amazon Bedrock guardrail configuration, use :dedent: 4 :start-after: [START howto_operator_bedrock_update_guardrail] :end-before: [END howto_operator_bedrock_update_guardrail] + + +.. _howto/operator:BedrockRerankOperator: + +Rerank Documents +---------------- + +To rerank a list of documents based on their relevance to a query using Amazon Bedrock, +you can use :class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockRerankOperator`. + +This operator uses the Bedrock Agent Runtime ``Rerank`` API to score and reorder documents, +which is useful for improving RAG pipeline quality by filtering and prioritizing retrieved +results before passing them to a generative model. + +.. code-block:: python + + from airflow.providers.amazon.aws.operators.bedrock import BedrockRerankOperator + + rerank = BedrockRerankOperator( + task_id="rerank_results", + query="What is serverless computing?", + documents=[ + {"textDocument": {"text": "AWS Lambda is a serverless compute service."}}, + {"textDocument": {"text": "Amazon EC2 provides virtual servers in the cloud."}}, + {"textDocument": {"text": "Serverless eliminates infrastructure management."}}, + ], + model_id="cohere.rerank-v3-5:0", + number_of_results=2, + ) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py index d4c6c7fb36ed4..3a73133b15a6f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/bedrock.py @@ -1334,3 +1334,91 @@ def execute(self, context: Context) -> str: job_arn = response["jobArn"] self.log.info("Created evaluation job %s: %s", self.job_name, job_arn) return job_arn + + +class BedrockRerankOperator(AwsBaseOperator[BedrockAgentRuntimeHook]): + """ + Rerank a list of documents based on their relevance to a query using Amazon Bedrock. + + Uses the Bedrock Agent Runtime ``Rerank`` API to score and reorder documents + by relevance, which is useful for improving RAG pipeline quality by filtering + and prioritizing retrieved results before passing them to a generative model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockRerankOperator` + + :param query: The search query to rerank documents against. (templated) + :param documents: List of documents to rerank. Each document should be a dict + with at minimum a 'textDocument' key containing {'text': '...'}. (templated) + :param model_id: The model ID for the reranking model. + Defaults to 'cohere.rerank-v3-5:0'. (templated) + :param number_of_results: Maximum number of results to return. + If not specified, all documents are returned reranked. (templated) + :param rerank_kwargs: Additional keyword arguments to pass to the Rerank API call. (templated) + """ + + aws_hook_class = BedrockAgentRuntimeHook + template_fields: Sequence[str] = aws_template_fields( + "query", + "documents", + "model_id", + "number_of_results", + "rerank_kwargs", + ) + + def __init__( + self, + query: str, + documents: list[dict[str, Any]], + model_id: str = "cohere.rerank-v3-5:0", + number_of_results: int | None = None, + rerank_kwargs: dict[str, Any] | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.query = query + self.documents = documents + self.model_id = model_id + self.number_of_results = number_of_results + self.rerank_kwargs = rerank_kwargs or {} + + def execute(self, context: Context) -> list[dict[str, Any]]: + self.log.info( + "Reranking %d documents with model %s", + len(self.documents), + self.model_id, + ) + + sources = [{"inlineDocumentSource": doc} for doc in self.documents] + + rerank_config: dict[str, Any] = { + "type": "INLINE", + "inlineDocumentSources": sources, + } + + queries = [{"type": "TEXT", "textQuery": {"text": self.query}}] + + reranking_configuration: dict[str, Any] = { + "bedrockRerankingConfiguration": { + "modelConfiguration": {"modelArn": f"arn:aws:bedrock:{self.hook.region_name}::foundation-model/{self.model_id}"}, + } + } + + if self.number_of_results is not None: + reranking_configuration["bedrockRerankingConfiguration"]["numberOfResults"] = self.number_of_results + + kwargs: dict[str, Any] = prune_dict( + { + "sources": [rerank_config], + "queries": queries, + "rerankingConfiguration": reranking_configuration, + **self.rerank_kwargs, + } + ) + + response = self.hook.conn.rerank(**kwargs) + results = response.get("results", []) + + self.log.info("Reranking complete. Returned %d results.", len(results)) + return results