# 集成检索器


![](img/img_2.png)

In [5]:
# !pip install -q -U langchain openai chromadb tiktoken rank_bm25
# !pip install -q -U  rank_bm25


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers import BM25Retriever
from langchain.vectorstores import Chroma

In [2]:
from typing import Any, Dict, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain.pydantic_v1 import root_validator
from langchain.schema import BaseRetriever, Document


class EnsembleRetriever(BaseRetriever):
    """Retriever that ensembles the results of multiple retrievers.

    The ranked lists returned by the individual retrievers are merged with
    weighted Reciprocal Rank Fusion (RRF).

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers. Defaults to
            equal weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between the
            importance of high-ranked items and the consideration given to
            lower-ranked items. Default is 60.
    """

    retrievers: List[BaseRetriever]
    weights: List[float]
    c: int = 60

    @root_validator(pre=True)
    def set_weights(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fill in equal weights when the caller did not supply any.

        Args:
            values: The raw constructor keyword arguments.

        Returns:
            The same mapping, with ``weights`` populated when it was missing
            or empty.
        """
        if not values.get("weights"):
            # Use .get() so a missing ``retrievers`` argument is reported by
            # the normal field validation instead of a raw KeyError here.
            n_retrievers = len(values.get("retrievers") or [])
            # Guard the division: an empty retriever list gets empty weights
            # rather than raising ZeroDivisionError.
            values["weights"] = (
                [1 / n_retrievers] * n_retrievers if n_retrievers else []
            )
        return values

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager for this retriever run.

        Returns:
            A list of reranked documents.
        """
        # Get the fused result of the retrievers.
        return self.rank_fusion(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for this retriever run.

        Returns:
            A list of reranked documents.
        """
        # Get the fused result of the retrievers.
        return await self.arank_fusion(query, run_manager)

    def rank_fusion(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Query every retriever and fuse the results.

        The per-retriever result lists are combined with
        ``weighted_reciprocal_rank``.

        Args:
            query: The query to search for.
            run_manager: Callback manager; each retriever gets a child
                manager tagged ``retriever_<n>`` (1-based).

        Returns:
            A list of reranked documents.
        """
        # Get the results of all retrievers.
        retriever_docs = [
            retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child(tag=f"retriever_{i + 1}")
            )
            for i, retriever in enumerate(self.retrievers)
        ]

        # Apply rank fusion.
        return self.weighted_reciprocal_rank(retriever_docs)

    async def arank_fusion(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously query every retriever and fuse the results.

        The retrievers are queried concurrently; the per-retriever result
        lists are combined with ``weighted_reciprocal_rank``.

        Args:
            query: The query to search for.
            run_manager: Async callback manager; each retriever gets a child
                manager tagged ``retriever_<n>`` (1-based).

        Returns:
            A list of reranked documents.
        """
        # Local import keeps this block self-contained.
        import asyncio

        # Run all retrievers concurrently instead of awaiting them one by one.
        # gather() returns results in the order the awaitables were passed, so
        # each result list still lines up with its entry in self.weights.
        retriever_docs = await asyncio.gather(
            *[
                retriever.aget_relevant_documents(
                    query,
                    callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
                )
                for i, retriever in enumerate(self.retrievers)
            ]
        )

        # Apply rank fusion.
        return self.weighted_reciprocal_rank(retriever_docs)

    def weighted_reciprocal_rank(
        self, doc_lists: List[List[Document]]
    ) -> List[Document]:
        """Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf

        Documents are identified by their ``page_content``. A document at
        1-based rank ``r`` in a list with weight ``w`` contributes
        ``w / (r + c)`` to its score.

        Args:
            doc_lists: A list of rank lists, where each rank list contains
                unique items.

        Returns:
            The final aggregated list of documents sorted by their weighted
            RRF scores in descending order. Documents with equal scores keep
            the order in which they were first seen.

        Raises:
            ValueError: If the number of rank lists differs from the number
                of weights.
        """
        if len(doc_lists) != len(self.weights):
            raise ValueError(
                "Number of rank lists must be equal to the number of weights."
            )

        # Accumulate the RRF score of every distinct page_content. A dict
        # preserves first-seen order, which makes both the printed trace and
        # the tie order of the stable sort below deterministic (a set would
        # iterate in a hash-dependent order).
        rrf_score_dic: Dict[str, float] = {}
        for doc_list, weight in zip(doc_lists, self.weights):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score = weight * (1 / (rank + self.c))
                rrf_score_dic[doc.page_content] = (
                    rrf_score_dic.get(doc.page_content, 0.0) + rrf_score
                )

        # Show the fused scores so the notebook reader can follow the ranking.
        for key, value in rrf_score_dic.items():
            print(f'Key: {key}, Value: {value}')

        # Sort page contents by their RRF scores in descending order.
        sorted_documents = sorted(
            rrf_score_dic, key=rrf_score_dic.get, reverse=True
        )

        # Map the sorted page_content back to the original document objects.
        # When several documents share a page_content, the last one seen wins.
        page_content_to_doc_map = {
            doc.page_content: doc for doc_list in doc_lists for doc in doc_list
        }
        return [
            page_content_to_doc_map[page_content]
            for page_content in sorted_documents
        ]

In [3]:
# Cloudflare Workers AI: set up the LLM and embedding clients.
from dotenv import load_dotenv
import os
from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI

# Load the .env file from the current directory.
# override=True re-reads .env and lets its values replace variables that are
# already set in the environment (plain load_dotenv() would keep them).
load_dotenv(override=True)

# Variables from .env can now be read like ordinary environment variables.
account_id = os.getenv('CF_ACCOUNT_ID')
api_token = os.getenv('CF_API_TOKEN')

# Report only whether the credentials were found. Printing the values would
# leak the secrets into the notebook output.
print(f"CF_ACCOUNT_ID loaded: {bool(account_id)}")
print(f"CF_API_TOKEN loaded: {bool(api_token)}")

import getpass

model = '@cf/meta/llama-3-8b-instruct'
cf_llm = CloudflareWorkersAI(account_id=account_id, api_token=api_token, model=model)

# Current way of creating embeddings: the cloudflare_workersai integration.
from langchain_community.embeddings.cloudflare_workersai import (
    CloudflareWorkersAIEmbeddings,
)

# Embedding dimension: 384
embeddings = CloudflareWorkersAIEmbeddings(
    account_id=account_id,
    api_token=api_token,
    model_name="@cf/baai/bge-small-en-v1.5",
)

[REDACTED: CF_ACCOUNT_ID]
[REDACTED: CF_API_TOKEN]


In [6]:
# Toy corpus indexed by both retrievers below.
doc_list = [
    "I like apples",
    "I like oranges",
    "Apples and oranges are fruits"
]

# Initialize the BM25 retriever and the Chroma vector-store retriever.
# BM25 (see Wikipedia), also known as Okapi BM25, is a ranking function used by
# information retrieval systems to estimate the relevance of documents to a
# given search query.
bm25_retriever = BM25Retriever.from_texts(doc_list)
# presumably caps the number of returned documents at 2 -- confirm against the
# BM25Retriever documentation
bm25_retriever.k = 2

# Query the keyword retriever on its own to inspect its ranking.
docs = bm25_retriever.get_relevant_documents('apple')
print(docs)

# Embed the same texts into a Chroma collection and expose it as a retriever
# that is asked for k=2 results per query.
vectorstore = Chroma.from_texts(doc_list, embeddings, collection_name="tutorial_2023")
vs_retriever = vectorstore.as_retriever(search_kwargs={"k": 2})

# Query the vector-store retriever on its own for comparison.
vs_docs = vs_retriever.get_relevant_documents('apple')
print(vs_docs)

# Initialize the ensemble retriever, weighting both retrievers equally.
ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, vs_retriever], weights=[0.5, 0.5])

  warn_deprecated(


[Document(page_content='Apples and oranges are fruits'), Document(page_content='I like oranges')]
[Document(page_content='I like apples'), Document(page_content='Apples and oranges are fruits')]


In [7]:
# Ensemble retriever (several retrievers combined into one): run the same
# query through it and display the fused result list.
docs = ensemble_retriever.get_relevant_documents("apple")
docs

Key: I like oranges, Value: 0.008064516129032258
Key: I like apples, Value: 0.00819672131147541
Key: Apples and oranges are fruits, Value: 0.01626123744050767


[Document(page_content='Apples and oranges are fruits'),
 Document(page_content='I like apples'),
 Document(page_content='I like oranges')]