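## Automatic Question Answering with Open-Source LLMs and an External Knowledge Base

Open-source large language models offer low deployment cost and controllable output. However, smaller models such as Llama2-7B and Zephyr-7B often hallucinate when answering questions about fine-grained details, and these hallucinations reduce the accuracy of the final output. We therefore bring an external knowledge base into the generation process to improve the accuracy and trustworthiness of the generated content.

### Installing dependencies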

In [None]:
%pip install langchain langchain-experimental text_generation InstructorEmbedding replicate --upgrade

In [None]:
# %pip install getpass
%pip install sqlalchemy==1.4.48
%pip install clickhouse-sqlalchemy==0.2.4

In [1]:
from sqlalchemy import __version__
print(__version__)

1.4.48


### Using Hugging Face open-source models and inference resources

In [41]:
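from typing import Dict
from langchain.prompts import PromptTemplate
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
# Message types used by generate_prompt_baichuan2 when chat history is passed in.
from langchain.schema.messages import HumanMessage, AIMessage
import json
import logging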

sagemaker_endpoint_name = 'mt-djl-baichuan2-13b-4bits-g5'

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Build the JSON request body sent to the endpoint.
        payload = {"inputs": prompt, "parameters": {"temperature": 0.01, "max_new_tokens": 1024}}
        logging.info("prompt: %s", prompt)
        return json.dumps(payload, ensure_ascii=False).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # The endpoint response is read as a stream and decoded from UTF-8 JSON.
        response_json = json.loads(output.read().decode("utf-8"))
        logging.info("response_json: %s", response_json)
        return response_json["outputs"]


baichuan_model = SagemakerEndpoint(
    endpoint_name=sagemaker_endpoint_name, 
    region_name="us-east-1",
    content_handler=ContentHandler()
)


def generate_prompt_baichuan2(user_input, instruct, history=None):
    """Build the role/content message list sent to the Baichuan2 endpoint."""
    messages = []
    if instruct:
        messages.append({"role": "system", "content": instruct})
    # Avoid a mutable default argument: a missing history is treated as empty.
    for msg in history or []:
        if isinstance(msg, HumanMessage):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMessage):
            messages.append({"role": "assistant", "content": msg.content})
    messages.append({"role": "user", "content": user_input})
    return messages
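
To make the message format concrete, here is a minimal sketch that reproduces the same role mapping without LangChain or a live endpoint. `HumanMsg` and `AIMsg` are hypothetical stand-ins for LangChain's `HumanMessage` and `AIMessage`, and `build_messages` is an illustrative name that does not appear in the notebook.

```python
import json
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class HumanMsg:  # hypothetical stand-in for LangChain's HumanMessage
    content: str


@dataclass
class AIMsg:  # hypothetical stand-in for LangChain's AIMessage
    content: str


def build_messages(
    user_input: str, instruct: Optional[str] = None, history: Optional[list] = None
) -> List[Dict[str, str]]:
    """Map an optional system instruction, chat history and the new input to role/content dicts."""
    messages = []
    if instruct:
        messages.append({"role": "system", "content": instruct})
    for msg in history or []:
        if isinstance(msg, HumanMsg):
            messages.append({"role": "user", "content": msg.content})
        elif isinstance(msg, AIMsg):
            messages.append({"role": "assistant", "content": msg.content})
    messages.append({"role": "user", "content": user_input})
    return messages


history = [HumanMsg("Hi"), AIMsg("Hello, how can I help?")]
messages = build_messages("Who is Geoffrey Hinton?", instruct="Be concise.", history=history)
# The serialized JSON string is what would be passed to the model as the prompt.
prompt = json.dumps(messages, ensure_ascii=False)
print(prompt)
```

The roles come out in the order system, user, assistant, user, so the new question is always the last message in the list.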


In [42]:
payload = generate_prompt_baichuan2('When did Geoffrey Hinton born?', instruct=None)
prompt = json.dumps(payload)
baichuan_model(prompt)

'Geoffrey Hinton was born on February 6, 1947.'

We can check the actual date on Wikipedia:
<iframe
	src="https://en.wikipedia.org/wiki/Geoffrey_Hinton"
	frameborder="0"
	width="1080"
	height="500"
></iframe>

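Geoffrey Hinton was not born on February 6 (Wikipedia gives 6 December 1947), so the model's answer is a hallucination and we need the help of external knowledge.

### Building the search: creating the embedding model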

In [43]:
from langchain.embeddings import SentenceTransformerEmbeddings

emb_model = SentenceTransformerEmbeddings(
    model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
)


### Building the search: connecting to the database

In [44]:
from sqlalchemy import create_engine, MetaData

MYSCALE_USER = "chatdata"
MYSCALE_PASSWORD = "myscale_rocks"
MYSCALE_HOST = "msc-1decbcc9.us-east-1.aws.staging.myscale.cloud"
MYSCALE_PORT = 443

engine = create_engine(
    f"clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}/wiki?protocol=https"
)
metadata = MetaData(bind=engine)


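### Building the search: building the query constructor

#### About `Vector SQL`

<img src="https://myscale.com/blog/assets/img/pipeline.015b6008.png" height="300px">

Because SQL with vector search closely resembles regular SQL, we can have the large language model generate an intermediate form of vector search, namely `Vector SQL`: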

```sql
SELECT * FROM table
ORDER BY DISTANCE(vector, NeuralArray(flower))
LIMIT 10
```

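Through the prompt, the language model learns to use the distance function `DISTANCE` and the text feature-extraction function `NeuralArray`.
It can also learn to freely combine different filter conditions, so the search query the user has in mind can be built more automatically.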
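
As a rough illustration of what happens to this intermediate form, the sketch below replaces each `NeuralArray(...)` call with a literal vector. It is not the LangChain parser: `toy_embed` is a hypothetical stand-in for a real embedding model, `expand_neural_array` is an illustrative name, and the real pipeline may format the vector differently.

```python
import re
from typing import Callable, List


def toy_embed(text: str) -> List[float]:
    """Hypothetical stand-in for an embedding model: returns a tiny two-number vector."""
    return [float(len(text)), float(sum(map(ord, text)) % 7)]


def expand_neural_array(vector_sql: str, embed: Callable[[str], List[float]]) -> str:
    """Replace every NeuralArray(<text>) with a bracketed list of floats."""

    def _sub(match):
        # Strip whitespace and optional quotes around the entity text.
        entity = match.group(1).strip().strip("'\"")
        vector = embed(entity)
        return "[" + ",".join(f"{v:g}" for v in vector) + "]"

    return re.sub(r"NeuralArray\(([^)]*)\)", _sub, vector_sql)


vector_sql = "SELECT * FROM table ORDER BY DISTANCE(vector, NeuralArray(flower)) LIMIT 10"
expanded = expand_neural_array(vector_sql, toy_embed)
print(expanded)
```

The language model only has to write the text it wants to search for; turning that text into numbers is left to the embedding model at parse time.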

Below is the code we use to turn the language model's output into vector-search SQL:


In [45]:
from typing import List, Dict, Any
from langchain_experimental.sql.vector_sql import VectorSQLOutputParser


class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
    """Based on VectorSQLOutputParser.
    It also modifies the SQL so that the required columns are always selected.
    """

    must_have_columns: List[str]

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        # Keep only the first non-trivial line of the LLM output.
        text = [line for line in text.strip().split("\n") if len(line) > 2][0]
        upper = text.upper()
        start = upper.find("SELECT")
        end = upper.find("FROM")
        # Rewrite only when both keywords are present and in the expected order.
        if start >= 0 and end > start:
            text = text.replace(
                text[start + len("SELECT") + 1 : end - 1],
                ", ".join(self.must_have_columns),
                1,
            )
        return super().parse(text)
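
The column rewrite can be tried on its own, without LangChain. The sketch below is a standalone variant that uses slicing instead of `str.replace`; `force_columns` is an illustrative name, and the simple keyword search would be confused by a column whose name contains `from`, so treat it as a demonstration rather than a robust SQL rewriter.

```python
from typing import List


def force_columns(sql: str, must_have_columns: List[str]) -> str:
    """Swap the column list between SELECT and FROM for the required columns."""
    upper = sql.upper()
    start = upper.find("SELECT")
    end = upper.find("FROM")
    if start < 0 or end <= start:
        return sql  # not a SELECT ... FROM statement, leave it untouched
    head = sql[: start + len("SELECT")]
    return f"{head} {', '.join(must_have_columns)} {sql[end:]}"


generated = (
    "SELECT title, id FROM Wikipedia "
    "ORDER BY DISTANCE(emb, NeuralArray(Geoffrey Hinton)) LIMIT 10"
)
rewritten = force_columns(generated, ["id", "title", "url", "text", "views"])
print(rewritten)
```

Whatever columns the model chose, the retriever then always receives the fields it needs to build documents (`text` for the content, `title` and `url` for the source).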


In [46]:
import prompts
import importlib

importlib.reload(prompts)

<module 'prompts' from '/home/ec2-user/SageMaker/workshops/myscale-aws-workshop/chat_with_database/prompts.py'>

In [47]:
from langchain.prompts import StringPromptTemplate
from langchain.schema.prompt_template import BasePromptTemplate
from pydantic import BaseModel, validator
from langchain.schema.messages import HumanMessage, AIMessage
from prompts import _myscale_prompt

class Baichuan2PromptTemplate(StringPromptTemplate, BaseModel):
    """A custom prompt template that fills the MyScale Vector SQL prompt and wraps it in the Baichuan2 chat message format."""

    def format(self, **kwargs) -> str:
        # Expected keys: "input", "table_info", "top_k"
        question = kwargs["input"]
        table_info = kwargs["table_info"]
        top_k = kwargs["top_k"]

        full_instruct = _myscale_prompt.format(input=question, table_info=table_info, top_k=top_k)

        # Serialize the chat messages to JSON, which is what the endpoint receives as the prompt.
        prompt_json = generate_prompt_baichuan2(full_instruct, instruct=None)
        return json.dumps(prompt_json, ensure_ascii=False)

    def _prompt_type(self):
        return "baichuan2-prompt"

### Building the search: integrating the LLM with the database

In [48]:
from langchain.prompts import PromptTemplate
from langchain.sql_database import SQLDatabase
from langchain_experimental.retrievers.vector_sql_database import (
    VectorSQLDatabaseChainRetriever,
)
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from langchain.llms import HuggingFaceTextGenInference


must_have_cols = ['id', 'title', 'url', 'text', 'views']

PROMPT = Baichuan2PromptTemplate(input_variables=["input", "table_info", "top_k"])

output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
    model=emb_model, must_have_columns=must_have_cols
)

sql_query_chain = VectorSQLDatabaseChain.from_llm(
    llm=baichuan_model,
    prompt=PROMPT,
    top_k=10,
    return_direct=True,
    db=SQLDatabase(engine, None, metadata, max_string_length=1024),
    sql_cmd_parser=output_parser,
    native_format=True,
)
sql_retriever = VectorSQLDatabaseChainRetriever(
    sql_db_chain=sql_query_chain, page_content_key="text"
)


### Running the query

In [49]:
from langchain.callbacks import StdOutCallbackHandler

docs = sql_retriever.get_relevant_documents("When did Geoffrey Hinton born?",
                                            callbacks=[StdOutCallbackHandler()])
docs




> Entering new VectorSQLDatabaseChain chain...
When did Geoffrey Hinton born?
SQLQuery:

> Entering new LLMChain chain...
Prompt after formatting:

> Finished chain.
SELECT Wikipedia.title, Wikipedia.id, Wikipedia.emb FROM Wikipedia WHERE has(authors, 'Geoffrey Hinton') AND emb IS NOT NULL

DatabaseException: Orig exception: Code: 47. DB::Exception: Missing columns: 'authors' while processing query: 'SELECT id, title, url, text, views FROM Wikipedia WHERE has(authors, 'Geoffrey Hinton') AND (emb IS NOT NULL)', required columns: 'url' 'title' 'views' 'id' 'text' 'authors' 'emb', maybe you meant: 'url', 'title', 'views', 'id', 'text' or 'emb'. (UNKNOWN_IDENTIFIER) (version 23.3.2.1)


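### Building RAG: connecting external knowledge to the generation prompt

First, we reuse the VectorSQL retriever built earlier. A prompt template then formats the retrieved documents and embeds them in the generation prompt.

Here we use LangChain's `RetrievalQAWithSourcesChain`.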

In [50]:
from langchain import LLMChain
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain

combine_prompt_template = (
    "You are a helpful document assistant. Your task is to answer any questions "
    + "related to the given documents. You should use the title and abstract of the selected documents as your source of information "
    + "and try to provide concise and accurate answers to any questions asked by the user. If you are unable to find "
    + "relevant information in the given sections, you will need to let the user know that the source does not contain "
    + "relevant information but still try to provide an answer based on your general knowledge. The following is the related information "
    + "about the document that will help you answer users' questions.\nHere the contexts:\n{summaries}\n\n\nQuestion: {question}"
    + "\nAnswer: "
)

COMBINE_PROMPT = PromptTemplate(
    input_variables=["summaries", "question"], template=combine_prompt_template)

doc_prompt = PromptTemplate(
            input_variables=["page_content", "url", "title"],
            template="Title: {title}\nContent: {page_content}\nSOURCE: {url}")

chain = RetrievalQAWithSourcesChain(
    retriever=sql_retriever,
    combine_documents_chain=StuffDocumentsChain(
        llm_chain=LLMChain(
            prompt=COMBINE_PROMPT,
            llm=baichuan_model,
        ),
        document_prompt=doc_prompt,
        document_variable_name="summaries",

    ),
    return_source_documents=True,
    max_tokens_limit=12000,
)
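
Conceptually, the "stuff" strategy renders every retrieved document with the document prompt, joins the results, and places them in the `{summaries}` slot of the combine prompt. The sketch below imitates that with plain string formatting. The templates are shortened copies of the ones above, the documents are placeholders invented for illustration, and the blank-line separator is an assumption about how the documents are joined.

```python
from typing import Dict, List

# Shortened copies of the notebook's templates.
DOC_TEMPLATE = "Title: {title}\nContent: {page_content}\nSOURCE: {url}"
COMBINE_TEMPLATE = "Here the contexts:\n{summaries}\n\n\nQuestion: {question}\nAnswer: "


def stuff_documents(docs: List[Dict[str, str]], question: str, separator: str = "\n\n") -> str:
    """Render each document, join them, and place the result in the combine prompt."""
    summaries = separator.join(DOC_TEMPLATE.format(**doc) for doc in docs)
    return COMBINE_TEMPLATE.format(summaries=summaries, question=question)


# Placeholder documents, invented for illustration only.
docs = [
    {"title": "Doc A", "page_content": "First placeholder passage.", "url": "https://example.org/a"},
    {"title": "Doc B", "page_content": "Second placeholder passage.", "url": "https://example.org/b"},
]
final_prompt = stuff_documents(docs, "When did Geoffrey Hinton born?")
print(final_prompt)
```

Because every document is pasted into a single prompt, the total length grows with the number of retrieved rows, which is why the chain above sets `max_tokens_limit`.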


In [51]:
chain("When did Geoffrey Hinton born?", callbacks=[StdOutCallbackHandler()])



> Entering new RetrievalQAWithSourcesChain chain...


> Entering new VectorSQLDatabaseChain chain...
When did Geoffrey Hinton born?
SQLQuery:

> Entering new LLMChain chain...
Prompt after formatting:

> Finished chain.
SELECT Wikipedia.title, Wikipedia.id, Wikipedia.emb FROM Wikipedia WHERE has(authors, 'Geoffrey Hinton') AND emb IS NOT NULL

DatabaseException: Orig exception: Code: 47. DB::Exception: Missing columns: 'authors' while processing query: 'SELECT id, title, url, text, views FROM Wikipedia WHERE has(authors, 'Geoffrey Hinton') AND (emb IS NOT NULL)', required columns: 'url' 'title' 'views' 'id' 'text' 'authors' 'emb', maybe you meant: 'url', 'title', 'views', 'id', 'text' or 'emb'. (UNKNOWN_IDENTIFIER) (version 23.3.2.1)


In [44]:
import ast

test = 'okokoko'
test = "{'oko': 'vv11',}"

if '{' in test and  '}' in test:
    print(ast.literal_eval(test))
else:
    print(test)

{'oko': 'vv11'}
