# MapReduceDocumentsChain

## Source code analysis

### Class definition

In [None]:
class MapReduceDocumentsChain(BaseCombineDocumentsChain):
    """Combining documents by mapping a chain over them, then combining results.

    We first call `llm_chain` on each document individually, passing in the
    `page_content` and any other kwargs. This is the `map` step.

    We then process the results of that `map` step in a `reduce` step. This should
    likely be a ReduceDocumentsChain.
    """

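`MapReduceDocumentsChain` is a chain class for combining documents. It first applies `llm_chain` to each document individually (the `map` step), then merges those results in a `reduce` step. The class inherits from `BaseCombineDocumentsChain` and implements the map-reduce strategy for documents.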

### Class attributes

#### `llm_chain: LLMChain`

In [None]:
llm_chain: LLMChain

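- **Description**: The chain applied to each document, an `LLMChain` (a prompt plus a language model).
- **Role**: In the `map` step, `llm_chain` is called once per document to produce a per-document result.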

#### `reduce_documents_chain: BaseCombineDocumentsChain`

In [None]:
reduce_documents_chain: BaseCombineDocumentsChain

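- **Description**: The chain that merges the results produced by the `map` step. Per the class docstring, this should likely be a `ReduceDocumentsChain`.
- **Role**: In the `reduce` step, this chain is called on the `map` results to produce the final output.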

#### `document_variable_name: str`

In [None]:
document_variable_name: str

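- **Description**: The name of the prompt variable in `llm_chain` that receives the document text.
- **Role**: During the `map` step, each document's `page_content` is passed to `llm_chain` under this name.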
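To make the role of `document_variable_name` concrete, the sketch below reproduces how the per-document inputs are assembled in the `map` step. `Doc` and `build_map_inputs` are hypothetical stand-ins written for this illustration, not LangChain names; only the list comprehension mirrors the source shown later in `combine_docs`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def build_map_inputs(
    docs: List[Doc], document_variable_name: str, **kwargs: Any
) -> List[Dict[str, Any]]:
    """Build one input dict per document, keyed by the document variable name."""
    return [{document_variable_name: d.page_content, **kwargs} for d in docs]


inputs = build_map_inputs(
    [Doc("first review"), Doc("second review")],
    document_variable_name="page_content",
    language="English",  # extra kwargs are copied into every input dict
)
print(inputs)
```

Each dict in `inputs` is what a single `llm_chain` call would receive, so the prompt template needs a placeholder with the same name.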

#### `return_intermediate_steps: bool`

In [None]:
return_intermediate_steps: bool = False

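- **Description**: Whether to return the intermediate results. Defaults to `False`.
- **Role**: When set to `True`, the output also includes the per-document results of the `map` step under the `intermediate_steps` key.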

### Methods

#### `get_output_schema`

In [None]:
def get_output_schema(self, config: Optional[RunnableConfig] = None) -> Type[BaseModel]:
    if self.return_intermediate_steps:
        return create_model(
            "MapReduceDocumentsOutput",
            **{
                self.output_key: (str, None),
                "intermediate_steps": (List[str], None),
            },
        )
    return super().get_output_schema(config)

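- Description: Defines the output schema.
- Role: If `return_intermediate_steps` is `True`, returns a schema named `MapReduceDocumentsOutput` that contains the output key and `intermediate_steps`; otherwise it falls back to the parent class's schema.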

#### `output_keys`

In [None]:
@property
def output_keys(self) -> List[str]:
    """Expect input key.

    :meta private:
    """
    _output_keys = super().output_keys
    if self.return_intermediate_steps:
        _output_keys = _output_keys + ["intermediate_steps"]
    return _output_keys

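- Description: Returns the output keys this chain produces. The docstring says "Expect input key", but the property lists output keys.
- Role: Appends `intermediate_steps` to the parent class's output keys when `return_intermediate_steps` is `True`.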
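The effect of the flag on the output keys can be seen in isolation with a small sketch. `BaseChainSketch` and `MapReduceSketch` are hypothetical classes written for this illustration, and the `"output_text"` default is an assumption about the parent class rather than something shown in the source above.

```python
from typing import List


class BaseChainSketch:
    """Hypothetical base class exposing a single output key."""

    output_key: str = "output_text"  # assumed default, for illustration only

    @property
    def output_keys(self) -> List[str]:
        return [self.output_key]


class MapReduceSketch(BaseChainSketch):
    """Hypothetical subclass mirroring the `output_keys` override."""

    def __init__(self, return_intermediate_steps: bool = False) -> None:
        self.return_intermediate_steps = return_intermediate_steps

    @property
    def output_keys(self) -> List[str]:
        keys = super().output_keys
        if self.return_intermediate_steps:
            # Build a new list so the parent's list is left untouched.
            keys = keys + ["intermediate_steps"]
        return keys


print(MapReduceSketch().output_keys)
print(MapReduceSketch(return_intermediate_steps=True).output_keys)
```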

#### `get_reduce_chain`

In [None]:
@root_validator(pre=True)
def get_reduce_chain(cls, values: Dict) -> Dict:
    """For backwards compatibility."""
    if "combine_document_chain" in values:
        if "reduce_documents_chain" in values:
            raise ValueError(
                "Both `reduce_documents_chain` and `combine_document_chain` "
                "cannot be provided at the same time. `combine_document_chain` "
                "is deprecated, please only provide `reduce_documents_chain`"
            )
        combine_chain = values["combine_document_chain"]
        collapse_chain = values.get("collapse_document_chain")
        reduce_chain = ReduceDocumentsChain(
            combine_documents_chain=combine_chain,
            collapse_documents_chain=collapse_chain,
        )
        values["reduce_documents_chain"] = reduce_chain
        del values["combine_document_chain"]
        if "collapse_document_chain" in values:
            del values["collapse_document_chain"]

    return values

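- Description: A pre-validation hook that accepts the legacy `combine_document_chain` argument.
- Role: For backwards compatibility, wraps the legacy `combine_document_chain` (and `collapse_document_chain`, if given) in a `ReduceDocumentsChain` and stores it as `reduce_documents_chain`. Supplying both the legacy and the new argument raises a `ValueError`.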
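The migration logic can be exercised on a plain dictionary of constructor arguments. In the sketch below, `migrate_reduce_chain` is a hypothetical helper, and a nested dict stands in for the real `ReduceDocumentsChain` object. Unlike the validator, it works on a copy instead of mutating its input.

```python
from typing import Any, Dict


def migrate_reduce_chain(values: Dict[str, Any]) -> Dict[str, Any]:
    """Mimic the pre-validation step on a plain dict of constructor arguments."""
    values = dict(values)  # work on a copy (the real validator mutates in place)
    if "combine_document_chain" in values:
        if "reduce_documents_chain" in values:
            raise ValueError(
                "Provide only `reduce_documents_chain`; "
                "`combine_document_chain` is deprecated."
            )
        # A plain dict stands in for ReduceDocumentsChain here.
        values["reduce_documents_chain"] = {
            "combine_documents_chain": values.pop("combine_document_chain"),
            "collapse_documents_chain": values.pop("collapse_document_chain", None),
        }
    return values


migrated = migrate_reduce_chain(
    {"llm_chain": "map-chain", "combine_document_chain": "stuff-chain"}
)
print(migrated)
```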

#### `get_return_intermediate_steps`

In [None]:
@root_validator(pre=True)
def get_return_intermediate_steps(cls, values: Dict) -> Dict:
    """For backwards compatibility."""
    if "return_map_steps" in values:
        values["return_intermediate_steps"] = values["return_map_steps"]
        del values["return_map_steps"]
    return values

- Description: A backwards-compatibility hook for the legacy `return_map_steps` argument.
- Role: Renames the legacy `return_map_steps` argument to `return_intermediate_steps` before validation.
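One detail that is easy to miss is what happens when both the old and the new name are supplied. The sketch below replays the rename on a plain dict; `migrate_return_map_steps` is a hypothetical helper written for this illustration.

```python
from typing import Any, Dict


def migrate_return_map_steps(values: Dict[str, Any]) -> Dict[str, Any]:
    """Rename the legacy key, following the same assignment as the validator."""
    values = dict(values)  # copy so the caller's dict is not modified
    if "return_map_steps" in values:
        # The legacy value is assigned unconditionally, so it wins if both exist.
        values["return_intermediate_steps"] = values.pop("return_map_steps")
    return values


print(migrate_return_map_steps({"return_map_steps": True}))
print(
    migrate_return_map_steps(
        {"return_map_steps": True, "return_intermediate_steps": False}
    )
)
```

Because the assignment is unconditional, the legacy value overrides an explicitly passed `return_intermediate_steps` in this sketch, which follows from the validator code shown above.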

#### `get_default_document_variable_name`

In [None]:
@root_validator(pre=True)
def get_default_document_variable_name(cls, values: Dict) -> Dict:
    """Get default document variable name, if not provided."""
    if "llm_chain" not in values:
        raise ValueError("llm_chain must be provided")

    llm_chain_variables = values["llm_chain"].prompt.input_variables
    if "document_variable_name" not in values:
        if len(llm_chain_variables) == 1:
            values["document_variable_name"] = llm_chain_variables[0]
        else:
            raise ValueError(
                "document_variable_name must be provided if there are "
                "multiple llm_chain input_variables"
            )
    else:
        if values["document_variable_name"] not in llm_chain_variables:
            raise ValueError(
                f"document_variable_name {values['document_variable_name']} was "
                f"not found in llm_chain input_variables: {llm_chain_variables}"
            )
    return values

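- Description: Infers `document_variable_name` when it is not provided, and validates it when it is.
- Role: If the `llm_chain` prompt has exactly one input variable, that variable becomes the default. With multiple input variables the name must be given explicitly, and a name that is not among the prompt's input variables raises a `ValueError`.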
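The three outcomes of this validator (infer, accept, reject) can be tried out with a standalone function. `resolve_document_variable_name` is a hypothetical helper that takes the prompt's input variables as a plain list; it follows the branching of the validator but is not part of LangChain.

```python
from typing import List, Optional


def resolve_document_variable_name(
    input_variables: List[str], document_variable_name: Optional[str] = None
) -> str:
    """Infer or validate the document variable name for a prompt."""
    if document_variable_name is None:
        if len(input_variables) == 1:
            # Exactly one placeholder: it must be the document slot.
            return input_variables[0]
        raise ValueError(
            "document_variable_name must be provided if there are "
            "multiple llm_chain input_variables"
        )
    if document_variable_name not in input_variables:
        raise ValueError(
            f"document_variable_name {document_variable_name} was "
            f"not found in llm_chain input_variables: {input_variables}"
        )
    return document_variable_name


print(resolve_document_variable_name(["page_content"]))
print(resolve_document_variable_name(["context", "question"], "context"))
```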

#### `combine_docs`

In [None]:
def combine_docs(
    self,
    docs: List[Document],
    token_max: Optional[int] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> Tuple[str, dict]:
    """Combine documents in a map reduce manner.

    Combine by mapping first chain over all documents, then reducing the results.
    This reducing can be done recursively if needed (if there are many documents).
    """
    map_results = self.llm_chain.apply(
        # FYI - this is parallelized and so it is fast.
        [{self.document_variable_name: d.page_content, **kwargs} for d in docs],
        callbacks=callbacks,
    )
    question_result_key = self.llm_chain.output_key
    result_docs = [
        Document(page_content=r[question_result_key], metadata=docs[i].metadata)
        # This uses metadata from the docs, and the textual results from `results`
        for i, r in enumerate(map_results)
    ]
    result, extra_return_dict = self.reduce_documents_chain.combine_docs(
        result_docs, token_max=token_max, callbacks=callbacks, **kwargs
    )
    if self.return_intermediate_steps:
        intermediate_steps = [r[question_result_key] for r in map_results]
        extra_return_dict["intermediate_steps"] = intermediate_steps
    return result, extra_return_dict

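- Description: Combines documents in a map-reduce manner.
- Role:
  - `map` step: calls `llm_chain` on each document to produce a per-document result, then wraps each result in a new `Document` that keeps the original document's metadata.
  - `reduce` step: passes those documents to `reduce_documents_chain`, which merges them into the final output.
  - Optionally adds the `map` results to the returned dict under `intermediate_steps`.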
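The control flow of `combine_docs` can be followed without any model by swapping the two chains for plain functions. This is a minimal sketch under that assumption: `Doc`, `map_reduce`, `shout`, and `join_docs` are hypothetical names, and the toy functions stand in for `llm_chain` and `reduce_documents_chain`.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Tuple


@dataclass
class Doc:
    """Hypothetical stand-in for LangChain's Document."""

    page_content: str
    metadata: Dict[str, Any] = field(default_factory=dict)


def map_reduce(
    docs: List[Doc],
    map_fn: Callable[[str], str],
    reduce_fn: Callable[[List[Doc]], str],
    return_intermediate_steps: bool = False,
) -> Tuple[str, Dict[str, Any]]:
    # Map: one call per document.
    map_results = [map_fn(d.page_content) for d in docs]
    # Wrap each result in a new Doc, keeping the source document's metadata.
    result_docs = [
        Doc(page_content=r, metadata=docs[i].metadata)
        for i, r in enumerate(map_results)
    ]
    # Reduce: merge all mapped documents into one string.
    result = reduce_fn(result_docs)
    extra: Dict[str, Any] = {}
    if return_intermediate_steps:
        extra["intermediate_steps"] = map_results
    return result, extra


def shout(text: str) -> str:
    """Toy map function standing in for the per-document LLM call."""
    return text.upper()


def join_docs(docs: List[Doc]) -> str:
    """Toy reduce function standing in for the reduce chain."""
    return " | ".join(d.page_content for d in docs)


result, extra = map_reduce(
    [Doc("a", {"source": "1"}), Doc("b")],
    shout,
    join_docs,
    return_intermediate_steps=True,
)
print(result, extra)
```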

#### `acombine_docs`

In [None]:
async def acombine_docs(
    self,
    docs: List[Document],
    token_max: Optional[int] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> Tuple[str, dict]:
    """Async combine documents in a map reduce manner.

    Combine by mapping first chain over all documents, then reducing the results.
    This reducing can be done recursively if needed (if there are many documents).
    """
    map_results = await self.llm_chain.aapply(
        # FYI - this is parallelized and so it is fast.
        [{**{self.document_variable_name: d.page_content}, **kwargs} for d in docs],
        callbacks=callbacks,
    )
    question_result_key = self.llm_chain.output_key
    result_docs = [
        Document(page_content=r[question_result_key], metadata=docs[i].metadata)
        # This uses metadata from the docs, and the textual results from `results`
        for i, r in enumerate(map_results)
    ]
    result, extra_return_dict = await self.reduce_documents_chain.acombine_docs(
        result_docs, token_max=token_max, callbacks=callbacks, **kwargs
    )
    if self.return_intermediate_steps:
        intermediate_steps = [r[question_result_key] for r in map_results]
        extra_return_dict["intermediate_steps"] = intermediate_steps
    return result, extra_return_dict

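- Description: The asynchronous version of `combine_docs`.
- Role:
  - Runs the same map-reduce procedure with `await`, using `aapply` for the `map` step and `acombine_docs` for the `reduce` step.
  - As in the synchronous version, the `map` results are passed to `reduce_documents_chain` for the final merge, and `intermediate_steps` is returned when requested.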
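The asynchronous shape of the same flow can be sketched with `asyncio` alone. Here `asyncio.gather` is used as a stand-in for whatever concurrency `aapply` provides, which is an assumption of this sketch; `amap_reduce`, `fake_summarize`, and `fake_combine` are hypothetical names.

```python
import asyncio
from typing import Awaitable, Callable, List


async def amap_reduce(
    texts: List[str],
    amap_fn: Callable[[str], Awaitable[str]],
    areduce_fn: Callable[[List[str]], Awaitable[str]],
) -> str:
    # Map: schedule all per-document calls together; gather keeps input order.
    map_results = await asyncio.gather(*(amap_fn(t) for t in texts))
    # Reduce: a single awaited call over all mapped results.
    return await areduce_fn(list(map_results))


async def fake_summarize(text: str) -> str:
    """Toy async map function: keep only the first word."""
    await asyncio.sleep(0)  # yield control, as a network call would
    return text.split()[0]


async def fake_combine(parts: List[str]) -> str:
    """Toy async reduce function."""
    await asyncio.sleep(0)
    return "; ".join(parts)


final = asyncio.run(
    amap_reduce(["alpha document", "beta document"], fake_summarize, fake_combine)
)
print(final)
```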

#### `_chain_type`

In [2]:
@property
def _chain_type(self) -> str:
    return "map_reduce_documents_chain"

- Description: Returns the chain type.
- Role: Identifies this chain as type `"map_reduce_documents_chain"`.

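### Summary

`MapReduceDocumentsChain` splits document processing into two stages: a `map` step that runs `llm_chain` over every document (in parallel, according to the source comment) and a `reduce` step that merges the results, recursively if needed when there are many documents. The outcome is a single combined output.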
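The idea of reducing recursively can be pictured with a simplified sketch: while the mapped texts exceed a budget, group them into batches that fit, combine each batch, and repeat. This is only an illustration of the idea, not the `ReduceDocumentsChain` implementation. Word counts stand in for token counts, and `count_tokens`, `split_into_batches`, `recursive_reduce`, and `first_words` are hypothetical names.

```python
from typing import Callable, List


def count_tokens(text: str) -> int:
    """Crude token estimate: whitespace-separated words (an assumption)."""
    return len(text.split())


def total_tokens(texts: List[str]) -> int:
    return sum(count_tokens(t) for t in texts)


def split_into_batches(texts: List[str], token_max: int) -> List[List[str]]:
    """Greedily group consecutive texts so each batch stays within token_max."""
    batches: List[List[str]] = []
    current: List[str] = []
    current_len = 0
    for t in texts:
        n = count_tokens(t)
        if current and current_len + n > token_max:
            batches.append(current)
            current, current_len = [], 0
        current.append(t)
        current_len += n
    if current:
        batches.append(current)
    return batches


def recursive_reduce(
    texts: List[str], combine: Callable[[List[str]], str], token_max: int
) -> str:
    """Collapse batches until everything fits, then combine once more."""
    while len(texts) > 1 and total_tokens(texts) > token_max:
        collapsed = [combine(b) for b in split_into_batches(texts, token_max)]
        if total_tokens(collapsed) >= total_tokens(texts):
            break  # no progress; stop instead of looping forever
        texts = collapsed
    return combine(texts)


def first_words(batch: List[str]) -> str:
    """Toy combine function: keep the first word of each text."""
    return " ".join(t.split()[0] for t in batch)


summary = recursive_reduce(
    ["a b c", "d e f", "g h i", "j k l"], first_words, token_max=6
)
print(summary)
```

With a budget of 6 words, the four inputs are first collapsed into two intermediate texts, which then fit the budget and are combined in a final pass.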

## demo

In [29]:
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.llm import LLMChain
from langchain_community.llms import OpenAI
from langchain.prompts import PromptTemplate
from langchain_community.chat_models import ChatTongyi
from langchain_core.documents.base import Document
import os


In [None]:

# Read the API key from the environment; it is used to initialize the ChatTongyi model
api_key = os.getenv("KEY_TONGYI")
if not api_key:
    raise ValueError("API Key is not set. Please ensure that the 'KEY_TONGYI' environment variable is set.")


# Initialize the ChatTongyi model; a lower temperature makes the generated text more deterministic
llm = ChatTongyi(
    dashscope_api_key=api_key,
    temperature=0,  # controls sampling randomness; 0 gives the most deterministic output
    streaming=True
)


In [None]:
document_prompt = PromptTemplate(
  input_variables=["page_content"],
  template="{page_content}"
)

llm_chain = LLMChain(
  llm=llm,
  prompt=document_prompt,
)
reduce_prompt = PromptTemplate.from_template("Summarize these reviews: {context}")

reduce_llm_chain = LLMChain(
  llm=llm,
  prompt=reduce_prompt
)


In [None]:
# Inserts the document contents into reduce_llm_chain and produces the combined output.
combine_documents_chain = StuffDocumentsChain(
  llm_chain=reduce_llm_chain,
  document_prompt=document_prompt,
  document_variable_name="context"
)
# The chain that merges the multiple results produced by the map stage.
reduce_documents_chain = ReduceDocumentsChain(
    combine_documents_chain=combine_documents_chain,
)


In [None]:
chain = MapReduceDocumentsChain(
    llm_chain=llm_chain,
    reduce_documents_chain=reduce_documents_chain,
)


In [None]:

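# Prepare a set of documents
docs = [
    Document(page_content="LangChain is a powerful library that helps you process documents with ease."),
    Document(page_content="OpenAI's GPT-3 model provides powerful text generation capabilities."),
    Document(page_content="By combining these tools, you can build powerful NLP applications.")
]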


In [None]:

# Run the chain to produce the final summary
result, _ = chain.combine_docs(docs)
print(result)


OSError: Can't load tokenizer for 'gpt2'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'gpt2' is the correct path to a directory containing all relevant files for a GPT2TokenizerFast tokenizer.