In [None]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.llms import Tongyi
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate

In [11]:
model = Tongyi(model='qwen-plus')  # type: ignore

In [12]:
# 自定义的文本拆分器
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=30, separators=['第'])
loader = TextLoader('../../data/rag/extracted_text.txt', encoding='utf-8')

docs = loader.load_and_split(text_splitter=text_splitter)

In [4]:
vectorstore = Chroma.from_documents(documents=docs, embedding=DashScopeEmbeddings(model='text-embedding-v3'))


In [5]:
retriever = vectorstore.as_retriever(search_type='similarity', search_kwargs={'k': 10})


In [6]:
template = '你是供电局的客服，请根据知识库，回答用户的问题，并列出引用自第几条 和第几点。问题：{question}。知识库：{docs}。不要输出答案如下之类的前言。'
prompt = PromptTemplate(template=template, input_variables=['question', 'docs'])
rag_chain = prompt | model | StrOutputParser()

In [7]:
def format_docs(docs):
    return '\n\n'.join(doc.page_content for doc in docs)


def ask_rag(question):
    formatted_docs = (retriever | format_docs).invoke(question)
    result = rag_chain.invoke({'question': question, 'docs': formatted_docs})
    return result

In [8]:
# ask_rag('用户集资建设的供电站，建成运营前，由谁管理')

In [None]:
from queue import Empty, Queue
from threading import Thread

from langchain.callbacks.base import BaseCallbackHandler
from langchain.prompts import PromptTemplate


class QueueCallbackHandler(BaseCallbackHandler):
    def __init__(self):
        self.q = Queue()
        self.done = False

    def on_llm_new_token(self, token: str, **kwargs):
        self.q.put(token)

    def get_generator(self):
        while not self.done or not self.q.empty():
            try:
                token = self.q.get(timeout=0.1)
                yield token
            except Empty:
                continue

    def finish(self):
        self.done = True


def stream_chain(prompt_text: str):
    handler = QueueCallbackHandler()
    formatted_docs = (retriever | format_docs).invoke(prompt_text)

    # 用线程非阻塞执行 chain.invoke（因为 streaming 是回调）
    def run_chain():
        rag_chain.stream({'question': prompt_text, 'docs': formatted_docs})
        handler.finish()

    Thread(target=run_chain).start()

    return handler.get_generator()

In [13]:
stream_chain('用户集资建设的供电站，建成运营前，由谁管理')

<generator object QueueCallbackHandler.get_generator at 0x000001E85DE7FC60>