In [None]:
import os
from dotenv import load_dotenv

# Load the OpenAI API environment variables from the .env file
load_dotenv()

In [None]:
import os
from langchain.llms import OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.document_loaders import TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS

In [None]:
# Set the OpenAI API key.
# Keep any key that is already present in the environment (for example one
# loaded from .env by load_dotenv()) and fall back to the placeholder only
# when none is set; an unconditional assignment would overwrite the real key.
os.environ.setdefault("OPENAI_API_KEY", "your_openai_api_key")

# Load documents
def load_documents(file_path, chunk_size=1000, chunk_overlap=0, encoding=None):
    """Load a text file and split it into chunks.

    Args:
        file_path: Path of the text file handed to ``TextLoader``.
        chunk_size: ``chunk_size`` passed to ``CharacterTextSplitter``.
            Defaults to 1000, the value that was previously hard-coded.
        chunk_overlap: ``chunk_overlap`` passed to ``CharacterTextSplitter``.
            Defaults to 0, the value that was previously hard-coded.
        encoding: Text encoding forwarded to ``TextLoader``. ``None`` (the
            default) leaves the choice to the loader, as before; pass e.g.
            ``"utf-8"`` when the file's encoding differs from the platform
            default.

    Returns:
        The list of documents produced by ``split_documents``.
    """
    loader = TextLoader(file_path, encoding=encoding)
    documents = loader.load()
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
    return text_splitter.split_documents(documents)

# Build the vector store
def create_vectorstore(docs):
    """Index ``docs`` in a FAISS vector store using ``OpenAIEmbeddings``.

    Args:
        docs: Documents passed to ``FAISS.from_documents``.

    Returns:
        The vector store returned by ``FAISS.from_documents``.
    """
    embedding_model = OpenAIEmbeddings()
    return FAISS.from_documents(docs, embedding_model)

# Generate multiple queries
def generate_multi_queries(llm, original_query):
    """Ask the language model for three rephrasings of ``original_query``.

    Args:
        llm: Language model handed to ``LLMChain``.
        original_query: Query text substituted into the prompt template.

    Returns:
        A list with one entry per non-blank line of the model's response,
        each with surrounding whitespace removed. The list is empty when the
        response contains no non-blank lines.
    """
    prompt_template = """
    请根据以下查询生成3个语义相近但表述不同的查询：
    {original_query}
    """
    prompt = PromptTemplate(input_variables=["original_query"], template=prompt_template)
    chain = LLMChain(llm=llm, prompt=prompt)
    result = chain.run(original_query=original_query)
    # The response may contain blank separator lines, indentation or "\r"
    # line endings; strip every line and drop the empty ones so that no
    # empty or whitespace-padded string is later used as a search query.
    stripped_lines = (line.strip() for line in result.splitlines())
    return [line for line in stripped_lines if line]

# Multi-query retrieval
def multi_query_retrieval(vectorstore, queries):
    """Run a similarity search per query and merge the hits without duplicates.

    Args:
        vectorstore: Object exposing ``similarity_search(query)``.
        queries: Iterable of query strings, searched in order.

    Returns:
        A list of the search results in first-seen order, where a result is
        dropped if an equal one (``==``) was already collected.
    """
    hits = [
        doc
        for query in queries
        for doc in vectorstore.similarity_search(query)
    ]
    # Simple equality-based de-duplication (can be refined for real workloads).
    deduped = []
    for doc in hits:
        if doc in deduped:
            continue
        deduped.append(doc)
    return deduped

# Main entry point
def main(file_path="example.txt", original_query="请介绍苹果公司的创新成果"):
    """Run the multi-query retrieval demo end to end and print the results.

    Args:
        file_path: Text file to load and index. Defaults to ``"example.txt"``,
            the path that was previously hard-coded.
        original_query: Query to rephrase and search for. Defaults to the
            query that was previously hard-coded.
    """
    # Load and split the document
    docs = load_documents(file_path)
    # Build the vector store
    vectorstore = create_vectorstore(docs)
    # Initialise the language model
    llm = OpenAI()
    # Generate the query variants
    queries = generate_multi_queries(llm, original_query)
    print("生成的查询：", queries)
    # Retrieve with every variant and merge the hits
    results = multi_query_retrieval(vectorstore, queries)
    print("检索结果数量：", len(results))
    for result in results:
        print(result.page_content[:100])  # first 100 characters of each hit

# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()