本项目是一个使用 LangChain + RAG + 微调（finetune）后的 ChatGLM3 模型充当 agent、完成下游任务的 demo。

In [None]:
# 导入必要的包
import requests
import os
from langchain.tools import BaseTool
from langchain.callbacks.manager import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel
from typing import List, Optional, Mapping, Any, Tuple, Union
from functools import partial
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from transformers import AutoModel, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType
import torch
from langchain.agents import AgentExecutor
import faiss
import numpy as np
from transformers import AutoModel,AutoTokenizer
from langchain.schema import AgentAction, AgentFinish
from langchain.agents import BaseSingleActionAgent
from langchain import LLMChain, PromptTemplate
from langchain.base_language import BaseLanguageModel



在本章节中，我们将定义可由 agent 调用的 tools。
Text_classification_Tool 会调用一个已在下游任务上微调（finetune）好的模型，并需要结合 context（即微调过程中使用的 Instruction 部分）完成模型预测。
整个过程区分了APITool和functional_Tool：
    APITool可以用于接入agent可调用api类工具（如搜索）
    functional_Tool可以用于完成下游模型的接入（如文本分类）

In [None]:
class APITool(BaseTool):
    """Base class for agent tools that answer a query by calling an external API.

    Subclasses fill in ``name``, ``description`` and ``url`` and implement
    ``_call_api``; the synchronous ``_run`` entry point only delegates to it.
    Async execution is not supported.
    """

    name: str = ""
    description: str = ""
    url: str = ""

    def _run(
            self,
            query: str,
            run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Synchronous tool entry point: forward *query* to ``_call_api``."""
        answer = self._call_api(query)
        return answer

    async def _arun(
            self,
            query: str,
            run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Not implemented: always raises NotImplementedError."""
        reason = "APITool does not support async"
        raise NotImplementedError(reason)

    def _call_api(self, query):
        """Perform the API request for *query*; subclasses must override this."""
        raise NotImplementedError("subclass needs to overwrite this method")

class functional_Tool(BaseTool):
    """Base class for agent tools backed by a local function or model (e.g. a fine-tuned classifier).

    Subclasses fill in ``name`` and ``description`` and implement
    ``_call_func``; the synchronous ``_run`` entry point only delegates to it.
    Async execution is not supported.
    """

    name: str = ""
    description: str = ""
    # Not used by this class; kept so the field set mirrors APITool.
    url: str = ""

    def _call_func(self, query):
        """Handle *query*; subclasses must override this."""
        raise NotImplementedError("subclass needs to overwrite this method")

    def _run(
            self,
            query: str,
            run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Synchronous tool entry point: forward *query* to ``_call_func``."""
        return self._call_func(query)

    async def _arun(
            self,
            query: str,
            run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Not implemented: always raises NotImplementedError naming this class."""
        raise NotImplementedError("functional_Tool does not support async")

# search tool #
class SearchTool(APITool):
    """Answer a question from Google Custom Search results summarised by ``llm``.

    The top ``top_k`` hits (title + snippet) are joined into a context string
    and fed together with the question into ``qa_template``.
    """
    llm: BaseLanguageModel

    name = "搜索问答"
    description = "根据用户问题搜索最新的结果，并返回Json格式的结果"

    # search params
    google_api_key: str
    google_cse_id: str
    url = "https://www.googleapis.com/customsearch/v1"
    # Number of search hits passed to the LLM as context.
    top_k = 2
    # Seconds to wait for the search API; without a timeout requests.get can
    # block the whole agent indefinitely.
    timeout: float = 10.0

    # QA params
    qa_template = """
    请根据下面带```分隔符的文本来回答问题。
    ```{text}```
    问题：{query}
    """
    prompt = PromptTemplate.from_template(qa_template)
    llm_chain: LLMChain = None

    def _call_api(self, query):
        """Search for *query* and let the LLM answer it from the retrieved snippets."""
        self.get_llm_chain()
        context = self.get_search_result(query)
        return self.llm_chain.predict(text=context, query=query)

    def get_search_result(self, query):
        """Return the title/snippet text of the top ``top_k`` hits for *query*.

        Returns:
            The hits joined by blank lines, or the literal string
            "No Search Result was found" when the API returns no items.

        Raises:
            requests.HTTPError: if the API answers with an error status (e.g. an
                invalid key or exhausted quota); previously such errors were
                indistinguishable from an empty result.
            requests.Timeout: if the API does not answer within ``timeout`` seconds.
        """
        params = {"key": self.google_api_key,
                  "cx": self.google_cse_id,
                  "q": query,
                  "lr": "lang_zh-CN"}
        response = requests.get(self.url, params=params, timeout=self.timeout)
        response.raise_for_status()
        results = response.json().get("items", [])[:self.top_k]
        if not results:
            return "No Search Result was found"
        snippets = []
        for result in results:
            print("result:", result)
            text = ""
            if "title" in result:
                text += result["title"] + "。"
            if "snippet" in result:
                text += result["snippet"]
            snippets.append(text)
        return "\n\n".join(snippets)

    def get_llm_chain(self):
        """Build the QA LLMChain on first use and cache it on the instance."""
        if not self.llm_chain:
            self.llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)

class Text_classification_Tool(functional_Tool):
    """Send a sentence to the fine-tuned LLM and return its sentiment answer.

    A fixed instruction naming the candidate emotions is placed in the
    ``text`` slot of the QA prompt; the user's sentence fills ``query``.
    """
    llm: BaseLanguageModel

    name = "文本分类"
    description = "用户输入句子，完成文本分类"

    # QA params
    qa_template = """
    请根据下面带```分隔符的文本来回答问题。
    ```{text}```
    问题：{query}
    """
    prompt = PromptTemplate.from_template(qa_template)
    llm_chain: LLMChain = None

    def get_llm_chain(self):
        """Create the LLMChain lazily; later calls reuse the cached chain."""
        if self.llm_chain:
            return
        self.llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)

    def _call_func(self, query) -> str:
        """Classify *query* by running it through the chain with the fixed instruction."""
        self.get_llm_chain()
        instruction = "Instruction: 深呼吸，你是一个文本分类模型。你需要根据我的输入给出这句话的情感，候选的情感为：开心、难过、平静"
        return self.llm_chain.predict(text=instruction, query=query)

在本章节中，我们按照 LangChain 的 LLM 接口定义一个 ChatGLM3 类。
其核心函数为 _call，内部调用 generate_resp 生成回复。
load_model 和 load_model_from_checkpoint 分别用于加载原始 ChatGLM3 模型，以及加载一个微调（finetune）好的模型供下游 tool 使用。
这里的finetune model使用的是 https://github.com/THUDM/ChatGLM3/tree/main/finetune_basemodel_demo 中的方法训练得到的model
训练数据为文本分类数据。

In [None]:
class ChatGLM3(LLM):
    """LangChain ``LLM`` wrapper around a ChatGLM3 model loaded with ``trust_remote_code``.

    Call ``load_model()`` (plain weights) or ``load_model_from_checkpoint()``
    (weights plus a LoRA adapter) before use; ``_call`` raises RuntimeError
    otherwise.
    """

    model_path: str
    max_length: int = 8192
    temperature: float = 0.1
    top_p: float = 0.7
    # Chat history passed to the model; generate_resp appends [prompt, reply] pairs.
    history: List = []
    # When True, generated text is echoed to stdout while it is produced.
    streaming: bool = True
    model: object = None
    tokenizer: object = None

    @property
    def _llm_type(self) -> str:
        return "chatglm3-6B"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        add_history: bool = False,
        **kwargs: Any,
    ) -> str:
        """Generate a reply for *prompt*.

        Args:
            prompt: The fully formatted prompt.
            stop: Accepted for LangChain compatibility; stop sequences are not
                applied to the output.
            run_manager: Unused; streamed text goes to a stdout callback handler.
            add_history: If True, record the exchange in ``self.history``.
            **kwargs: Extra keyword arguments LangChain may forward; ignored.

        Returns:
            The model's complete reply.

        Raises:
            RuntimeError: if no model/tokenizer has been loaded yet.
        """
        if self.model is None or self.tokenizer is None:
            raise RuntimeError("Must call `load_model()` to load model and tokenizer!")

        if self.streaming:
            # Echo each newly generated fragment to stdout as it arrives.
            text_callback = partial(StreamingStdOutCallbackHandler().on_llm_new_token, verbose=True)
            return self.generate_resp(prompt, text_callback, add_history=add_history)
        # Non-streaming: no callback, so generate_resp uses a single chat() call.
        return self.generate_resp(prompt, add_history=add_history)

    def generate_resp(self, prompt, text_callback=None, add_history=True):
        """Run the model on *prompt* and return the full reply.

        Args:
            prompt: Text sent to the model.
            text_callback: If given, generate with ``stream_chat`` and call this
                with each newly produced fragment; otherwise use one blocking
                ``chat`` call.
            add_history: If True, record ``[prompt, reply]`` in ``self.history``.

        NOTE(review): ChatGLM3's remote code may expect a different history
        entry format than ``[prompt, reply]`` lists — confirm before relying on
        ``add_history=True`` across calls.
        """
        resp = ""
        index = 0
        if text_callback:
            for i, (resp, _) in enumerate(self.model.stream_chat(
                    self.tokenizer,
                    prompt,
                    self.history,
                    max_length=self.max_length,
                    top_p=self.top_p,
                    temperature=self.temperature
            )):
                if add_history:
                    # First chunk opens a new turn; later chunks overwrite it
                    # with the longer reply.
                    if i == 0:
                        self.history += [[prompt, resp]]
                    else:
                        self.history[-1] = [prompt, resp]
                # resp is treated as cumulative text: emit only the unseen suffix.
                text_callback(resp[index:])
                index = len(resp)
        else:
            resp, _ = self.model.chat(
                self.tokenizer,
                prompt,
                self.history,
                max_length=self.max_length,
                top_p=self.top_p,
                temperature=self.temperature
            )
            if add_history:
                self.history += [[prompt, resp]]
        return resp

    def load_model(self):
        """Load tokenizer and fp16 model from ``model_path`` onto CUDA in eval mode (no-op if already loaded)."""
        if self.model is not None or self.tokenizer is not None:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda().eval()

    def load_model_from_checkpoint(self, checkpoint=None):
        """Load the base model, wrap it with a LoRA adapter and optionally load fine-tuned adapter weights.

        No-op if a model or tokenizer is already loaded.

        Args:
            checkpoint: Name of a known fine-tuned adapter. Only
                ``"text_classification"`` is recognised; any other value keeps
                the freshly initialised (untrained) LoRA adapter.

        Warns:
            UserWarning: if the adapter file is missing, or if some of its keys
                do not match the model (those weights are then not loaded).
        """
        import warnings

        if self.model is not None or self.tokenizer is not None:
            return
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half()
        peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=8,
                target_modules=['query_key_value'],
                lora_alpha=32,
                lora_dropout=0.1,
        )
        self.model = get_peft_model(self.model, peft_config).to("cuda")
        if checkpoint == "text_classification":
            model_dir = "./emo_classifcation/checkpoint/"  # replace with your own fine-tuned model directory
            peft_path = "{}/chatglm-lora.pt".format(model_dir)
            if os.path.exists(peft_path):
                # strict=False because the file is expected to hold only part
                # of the weights (presumably just the adapter) — confirm.
                load_result = self.model.load_state_dict(torch.load(peft_path), strict=False)
                if load_result.unexpected_keys:
                    warnings.warn(
                        "{} keys in {} did not match the model and were ignored".format(
                            len(load_result.unexpected_keys), peft_path))
            else:
                warnings.warn(
                    "LoRA checkpoint {} not found; using an untrained adapter".format(peft_path))
        # Inference only: disable the LoRA dropout (lora_dropout above) so
        # generation matches load_model(), which also switches to eval mode.
        self.model.eval()


在本章节中，我们需要完成一个意图识别 agent，用意图识别的结果充当路由，连接到下游的 tools。
在本项目中，agent 需要能根据输入文本判断其属于文本分类任务，从而选择 Text_classification_Tool；这需要借助外挂知识。
这里采用RAG的方式，通过检索出的知识来辅助LLM做意图识别。
包含一个IntentAgent，根据知识库召回的知识填充intent_template
利用select_tools = [(name, resp.index(name)) for name in tool_names if name in resp]判断下游应该使用哪个tools

In [None]:
#RAG模块
def process_data(file_path):
    """Read the knowledge-base text used for retrieval.

    Args:
        file_path: A directory, of which every regular file directly inside is
            read in sorted filename order (sub-directories are skipped), or the
            path of a single text file.

    Returns:
        list[str]: One entry per non-blank line, trailing newline kept, in
        file order.
    """
    if os.path.isfile(file_path):
        paths = [file_path]
    else:
        paths = [
            os.path.join(file_path, name)
            for name in sorted(os.listdir(file_path))
            if os.path.isfile(os.path.join(file_path, name))
        ]
    all_content = []
    for path in paths:
        with open(path, encoding="utf-8") as f:
            # Each line becomes one retrieval entry; blank lines carry no knowledge.
            all_content.extend(line for line in f if line.strip())
    return all_content

class DFaiss:
    """Minimal in-memory vector store: an exact L2 faiss index plus the text of each vector.

    ``text_str_list[i]`` is the text whose embedding was the i-th vector added
    to ``index``.
    """

    def __init__(self, dim=4096):
        """Create an empty flat L2 index for *dim*-wide vectors.

        Args:
            dim: Embedding width; the default 4096 is presumably the hidden size
                of the ChatGLM3 model emb_model uses — must match the vectors added.
        """
        self.index = faiss.IndexFlatL2(dim)
        self.text_str_list = []

    def search(self, emb, top_k=None, max_distance=100000):
        """Return the concatenated texts of the stored vectors nearest to *emb*.

        Args:
            emb: Query embedding; only the first query row's results are used.
            top_k: Maximum number of neighbours to consider; None means all
                stored vectors (the original behaviour).
            max_distance: Only neighbours whose faiss distance is strictly below
                this are returned (IndexFlatL2 reports squared L2 distances).

        Returns:
            The matching texts joined in order of increasing distance, or ""
            when the index is empty or nothing is close enough.
        """
        n_total = self.index.ntotal
        k = n_total if top_k is None else min(top_k, n_total)
        if k <= 0:
            return ""
        # Ask for at most the number of stored vectors: a larger k only pads
        # the result with id -1 entries.
        distances, ids = self.index.search(emb.astype(np.float32), k)
        return "".join(
            self.text_str_list[idx]
            for dist, idx in zip(distances[0], ids[0])
            if idx >= 0 and dist < max_distance
        )

class emb_model:
    """Embed text with a ChatGLM3 model and retrieve related knowledge from a DFaiss index."""

    def __init__(self, model_path="THUDM/chatglm3-6b-base"):
        """Load the tokenizer and fp16 model (on CUDA) from *model_path* and create an empty index.

        Args:
            model_path: Hugging Face model id or local path used for embeddings.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
        self.myfaiss = DFaiss()

    def retrive(self, text):
        """Return the knowledge text DFaiss.search finds for *text*."""
        emb = self.get_sentence_emb(text, is_numpy=True)
        return self.myfaiss.search(emb)

    def load_data(self, path):
        """Embed every entry from process_data(path) and add it to the index."""
        for content in process_data(path):
            emb = self.get_sentence_emb(content, is_numpy=True)
            # Add the vector before recording its text, so a failure while
            # embedding/adding cannot leave text_str_list and the index out of step.
            self.myfaiss.index.add(emb.astype(np.float32))
            self.myfaiss.text_str_list.append(content)

    def get_sentence_emb(self, text, is_numpy=False):
        """Embed *text* as the model's hidden state at the last token position.

        Args:
            text: Sentence to embed.
            is_numpy: If True return a NumPy array on the CPU, else a CUDA tensor.

        Returns:
            One embedding row for the single input sentence.
        """
        input_ids = self.tokenizer([text], return_tensors="pt")["input_ids"].to("cuda")
        # Inference only: skip autograd bookkeeping, which would otherwise keep
        # every layer's activations alive for a backward pass that never runs.
        with torch.no_grad():
            hidden = self.model.transformer(input_ids, return_dict=False)[0]
        # NOTE(review): assumes the transformer output is (seq, batch, hidden);
        # transpose to batch-first, then keep the last position.
        emb = hidden.transpose(0, 1)[:, -1]

        if is_numpy:
            emb = emb.detach().cpu().numpy()

        return emb

In [None]:
# agent模块
class IntentAgent(BaseSingleActionAgent):
    """Single-step router: uses RAG-retrieved knowledge plus an LLM to pick which tool handles the input.

    The LLM's reply to ``intent_template`` is scanned for tool names; the input
    is forwarded unchanged to the earliest-mentioned tool, and that tool's
    observation becomes the final output.
    """
    tools: List
    llm: BaseLanguageModel
    intent_template: str = """
    有一些参考资料，为:{docs}
    你的任务是根据「参考资料」来理解用户问题的意图，并判断该问题属于哪一类意图。
    用户问题：“{query}”
    """

    prompt = PromptTemplate.from_template(intent_template)
    llm_chain: LLMChain = None
    # Directory (or file) holding the knowledge base for retrieval.
    doc_path: str = "./doc/"
    # Retriever built on first use and reused: constructing it loads an
    # embedding model and re-embeds the whole knowledge base.
    ret_model: Any = None

    def get_llm_chain(self):
        """Build the intent LLMChain on first use and cache it."""
        if not self.llm_chain:
            self.llm_chain = LLMChain(llm=self.llm, prompt=self.prompt)

    def get_ret_model(self):
        """Return the cached RAG retriever, creating and indexing it on first call."""
        if self.ret_model is None:
            ret_model = emb_model()  # RAG
            ret_model.load_data(self.doc_path)
            self.ret_model = ret_model
        return self.ret_model

    def choose_tools(self, query):
        """Return the names of tools mentioned in the LLM's intent reply, earliest mention first.

        May return an empty list when the reply names no tool.
        """
        self.get_llm_chain()
        tool_names = [tool.name for tool in self.tools]
        docs = self.get_ret_model().retrive(query)
        resp = self.llm_chain.predict(query=query, docs=docs)
        select_tools = [(name, resp.index(name)) for name in tool_names if name in resp]
        select_tools.sort(key=lambda x: x[1])
        return [x[0] for x in select_tools]

    @property
    def input_keys(self):
        return ["input"]

    def plan(
            self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[AgentAction, AgentFinish]:
        """Route the input to a tool, or finish.

        - Once a tool has run, finish with its observation as the output.
        - If the LLM reply names no tool, finish with a fallback message
          instead of raising IndexError.
        """
        if intermediate_steps:
            return AgentFinish(return_values={"output": intermediate_steps[-1][1]}, log="")
        user_input = kwargs["input"]
        tool_names = self.choose_tools(user_input)
        if not tool_names:
            # Fallback output (Chinese, matching the rest of the demo): "no suitable tool found".
            return AgentFinish(return_values={"output": "未能识别出可以处理该问题的工具"}, log="")
        return AgentAction(tool=tool_names[0], tool_input=user_input, log="")

    def return_stopped_response(
            self,
            early_stopping_method: str,
            intermediate_steps: List[Tuple[AgentAction, str]],
            **kwargs: Any,
    ) -> AgentFinish:
        """Finish with the last tool observation when the executor stops on its iteration limit.

        Without this, hitting the limit right after the routed tool ran would
        replace its result with LangChain's generic stop message.
        """
        if intermediate_steps:
            return AgentFinish(return_values={"output": intermediate_steps[-1][1]}, log="")
        return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)

    async def aplan(
            self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs: Any
    ) -> Union[List[AgentAction], AgentFinish]:
        raise NotImplementedError("IntentAgent does not support async")


In [None]:
调用主函数

In [None]:
# Google Custom Search credentials: taken from the environment when set, so
# keys need not be hard-coded in the notebook.
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY", "")
GOOGLE_CSE_ID = os.environ.get("GOOGLE_CSE_ID", "")
model_path = "THUDM/chatglm3-6b-base"
llm_base = ChatGLM3(model_path=model_path)
llm_text_cls = ChatGLM3(model_path=model_path)
# Load weights: plain base model, and a second instance with the
# text-classification LoRA adapter.
llm_base.load_model()
llm_text_cls.load_model_from_checkpoint(checkpoint="text_classification")
# Tools the agent can route to; search answers are produced by the base model.
tools = [SearchTool(llm=llm_base, google_api_key=GOOGLE_API_KEY, google_cse_id=GOOGLE_CSE_ID),
         Text_classification_Tool(llm=llm_text_cls)]
# The intent-recognition agent is backed by the ChatGLM3 base model.
agent = IntentAgent(tools=tools, llm=llm_base)
agent_exec = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, max_iterations=1)
agent_exec.run("Input:今天丢失了我的钱包，里面有很重要的东西，心情很沮丧")