In [None]:
#切片
from typing import List

def split_into_chunks(doc_file: str) -> List[str]:
    """Split a text file into paragraph-level chunks.

    The file is decoded as UTF-8 and split on blank lines (two consecutive
    newline characters). Pieces that contain nothing but whitespace, which
    appear when paragraphs are separated by more than one blank line or the
    file ends with blank lines, are dropped so they are never embedded or
    indexed. The text of every remaining chunk is returned unchanged.

    Args:
        doc_file: Path of the document to read.

    Returns:
        The non-blank chunks in document order. An empty or all-whitespace
        file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Explicit encoding: without it open() uses the platform's preferred
    # encoding, which is not UTF-8 on every system.
    with open(doc_file, 'r', encoding='utf-8') as file:
        content = file.read()

    return [chunk for chunk in content.split("\n\n") if chunk.strip()]

# Chunk the source document; the cells below embed and index `chunks`.
chunks = split_into_chunks("doc.md")

# Show every chunk with its index; end="\n\n" leaves a blank line after each.
for i, chunk in enumerate(chunks):
    print(f"[{i}] {chunk}", end="\n\n")

In [None]:
# 导入sentence_transformers库
from sentence_transformers import SentenceTransformer
# Load the embedding model once at module level; embed_chunk() reuses it.
embedding_model = SentenceTransformer(
    model_name_or_path="shibing624/text2vec-base-chinese"
)
# 定义嵌入函数，将文本块转换为向量列表
def embed_chunk(chunk: str) -> list[float]:
    """Encode one piece of text into an embedding vector.

    Uses the module-level ``embedding_model`` and requests normalized
    embeddings (``normalize_embeddings=True``).

    Args:
        chunk: Text to embed.

    Returns:
        The model output converted to a plain Python list via ``.tolist()``.
    """
    # Alternative model left disabled in the original code:
    # 'paraphrase-multilingual-MiniLM-L12-v2'
    return embedding_model.encode(chunk, normalize_embeddings=True).tolist()
# 测试嵌入函数
test_embedding = embed_chunk("测试内容")
# Print the vector length first, then the vector itself, one per line.
print(len(test_embedding), test_embedding, sep="\n")

In [None]:
# 索引化，将文本块转换为向量列表
# One embedding per chunk, in the same order as `chunks`.
embeddings = list(map(embed_chunk, chunks))
print(len(embeddings))

# Preview the first vector, ten components per printed row.
vec = embeddings[0]
for i in range(0, len(vec), 10):
    print(vec[slice(i, i + 10)])

In [None]:
# 导入向量数据库，用于存储和检索向量数据
import chromadb

# EphemeralClient is Chroma's in-memory client; the "default" collection is
# created on first use and fetched if it already exists.
chromadb_client = chromadb.EphemeralClient()
chromadb_collection = chromadb_client.get_or_create_collection("default")

def save_embeddings(
    chunks: List[str],
    embeddings: List[List[float]],
    batch_size: int = 1000,
) -> None:
    """Store chunks and their embeddings in the module-level Chroma collection.

    Each record's id is the chunk's position in ``chunks`` rendered as a
    string ("0", "1", ...). Records are sent to ``chromadb_collection.add``
    in slices of at most ``batch_size`` instead of one call per chunk.

    Args:
        chunks: Chunk texts to store as documents.
        embeddings: One embedding per chunk, in the same order as ``chunks``.
        batch_size: Maximum number of records per ``add`` call.

    Raises:
        ValueError: If ``chunks`` and ``embeddings`` differ in length, or if
            ``batch_size`` is not positive.
    """
    # zip() would silently drop the surplus items; a mismatch means the caller
    # paired the wrong lists, so fail loudly instead.
    if len(chunks) != len(embeddings):
        raise ValueError(
            f"chunks and embeddings must have the same length "
            f"(got {len(chunks)} and {len(embeddings)})"
        )
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive (got {batch_size})")

    # NOTE(review): ids are positional, so running this again on an already
    # populated collection reuses the same ids - confirm whether `upsert`
    # is wanted for re-indexing.
    total = len(chunks)
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        chromadb_collection.add(
            documents=chunks[start:stop],
            embeddings=embeddings[start:stop],
            ids=[str(i) for i in range(start, stop)],
        )

# Index every chunk together with its embedding.
save_embeddings(chunks=chunks, embeddings=embeddings)

In [None]:
# 召回
def retrieve(query: str, top_k: int) -> List[str]:
    """Fetch stored chunks for a query from the module-level Chroma collection.

    The query text is embedded with ``embed_chunk`` and passed to
    ``chromadb_collection.query``.

    Args:
        query: Question text to search with.
        top_k: Number of results requested from the collection.

    Returns:
        The documents the collection returned for this query.
    """
    hits = chromadb_collection.query(
        query_embeddings=[embed_chunk(query)],
        n_results=top_k,
    )
    # A single query embedding was sent, so its matches are the first entry
    # of the per-query "documents" list.
    documents_per_query = hits['documents']
    return documents_per_query[0]

query = "哆啦A梦复制斗篷有什么用？"
# Ask for five candidates; the rerank step below narrows them down.
retrieved_chunks = retrieve(query, top_k=5)

for i, chunk in enumerate(retrieved_chunks):
    print(f"[{i}] {chunk}", end="\n\n")

In [None]:
# 重排
from sentence_transformers import CrossEncoder

# Loaded cross-encoders keyed by model name, so repeated rerank() calls reuse
# one instance instead of reloading the model on every call.
_CROSS_ENCODER_CACHE = {}


def _get_cross_encoder(model_name: str):
    """Return the cached CrossEncoder for ``model_name``, loading it on first use."""
    if model_name not in _CROSS_ENCODER_CACHE:
        _CROSS_ENCODER_CACHE[model_name] = CrossEncoder(model_name)
    return _CROSS_ENCODER_CACHE[model_name]


def rerank(
    query: str,
    retrieved_chunks: List[str],
    top_k: int,
    model_name: str = 'cross-encoder/mmarco-mMiniLMv2-L12-H384-v1',
) -> List[str]:
    """Re-order retrieved chunks by cross-encoder score and keep the best ones.

    Each (query, chunk) pair is scored with the cross-encoder's ``predict``.
    Chunks are sorted by score, highest first; the sort is stable, so chunks
    with equal scores keep their incoming order.

    Args:
        query: The user question.
        retrieved_chunks: Candidate chunks to score.
        top_k: Maximum number of chunks to return.
        model_name: Cross-encoder model to load; defaults to the model this
            function has always used.

    Returns:
        At most ``top_k`` chunks, best score first. Returns an empty list
        when there are no candidates or ``top_k`` is not positive.
    """
    # Nothing to score, or nothing requested: skip loading/calling the model.
    # A negative top_k would otherwise slice off the tail instead of
    # returning nothing.
    if not retrieved_chunks or top_k <= 0:
        return []

    cross_encoder = _get_cross_encoder(model_name)
    pairs = [(query, chunk) for chunk in retrieved_chunks]
    scores = cross_encoder.predict(pairs)

    ranked = sorted(
        zip(retrieved_chunks, scores),
        key=lambda scored: scored[1],
        reverse=True,
    )
    return [chunk for chunk, _ in ranked[:top_k]]

# Keep the three best-scoring candidates for the generation step.
reranked_chunks = rerank(query, retrieved_chunks, top_k=3)

for i, chunk in enumerate(reranked_chunks):
    print(f"[{i}] {chunk}", end="\n\n")

In [None]:
# 使用DeepSeek的API进行RAG
import requests
from dotenv import load_dotenv
import os
from typing import List

def validate_api_key():
    """Check that a DeepSeek API key has been configured.

    Reads the module-level ``DEEPSEEK_API_KEY``.

    Raises:
        ValueError: If the key is missing or empty. The message lists the
            setup steps for the ``.env`` file.
    """
    if DEEPSEEK_API_KEY:
        return

    setup_steps = (
        "DeepSeek API 密钥未设置！请确保：",
        "1. 在项目根目录创建 .env 文件",
        "2. 在 .env 文件中添加：DEEPSEEK_API_KEY=你的API密钥",
        "3. 确保 API 密钥格式正确且未过期",
    )
    raise ValueError("\n".join(setup_steps))

# 加载环境变量并验证
# load_dotenv() runs first so a key defined in a .env file is visible to the
# environment lookup on the next line.
load_dotenv()
DEEPSEEK_API_KEY = os.environ.get('DEEPSEEK_API_KEY')
DEEPSEEK_API_URL = "https://api.deepseek.com/v1/chat/completions"

def generate(query: str, chunks: List[str], timeout: float = 60.0) -> str:
    """Answer ``query`` with the DeepSeek chat API, using ``chunks`` as context.

    Builds a prompt from the question and the chunks, prints it, posts it to
    ``DEEPSEEK_API_URL`` with the key from ``DEEPSEEK_API_KEY``, and returns
    the text of the first choice in the response.

    Args:
        query: The user question.
        chunks: Context passages to include in the prompt.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The reply text. Returns an empty string, after printing a
        diagnostic, when the API key is missing, the request fails
        (including a 401 or a timeout), or the response does not have the
        expected shape.
    """
    try:
        # Raises ValueError (handled below) when no key is configured.
        validate_api_key()

        # Separate passages with a blank line so neighbouring chunks do not
        # run together in the prompt.
        context = "\n\n".join(chunks)

        prompt = f"""你是一位知识助手，请根据用户的问题和下列片段生成准确的回答。

用户问题: {query}

相关片段:
{context}
请基于上述内容作答，不要编造信息。"""

        print("正在生成回答...\n")
        print(f"提示词内容:\n{prompt}\n")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {DEEPSEEK_API_KEY}"
        }

        data = {
            "model": "deepseek-chat",  # model identifier sent to the API
            "messages": [
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "temperature": 0.7,
            "max_tokens": 2000
        }

        # Without a timeout, requests waits indefinitely on a stalled
        # connection; a timeout surfaces as a RequestException below.
        response = requests.post(
            DEEPSEEK_API_URL, headers=headers, json=data, timeout=timeout
        )

        # Authentication failures get a dedicated, more detailed message.
        if response.status_code == 401:
            print("认证错误！请检查：")
            print("1. API 密钥是否正确设置在 .env 文件中")
            print("2. API 密钥是否有效（未过期）")
            print("3. API 密钥格式是否正确")
            print(f"错误详情: {response.text}")
            return ""

        # Any other HTTP error status raises here.
        response.raise_for_status()

        result = response.json()
        return result['choices'][0]['message']['content']

    except requests.RequestException as e:
        print(f"API 请求错误: {e}")
        error_response = getattr(e, 'response', None)
        if error_response is not None:
            print(f"错误详情: {error_response.text}")
        return ""
    except (KeyError, IndexError, TypeError) as e:
        # TypeError covers a body whose JSON shape is wrong, e.g. a null
        # where a list or object was expected.
        print(f"解析响应出错: {e}")
        return ""
    except ValueError as e:
        # Raised by validate_api_key(); its message already explains the fix.
        print(e)
        return ""
# Generate the final answer from the reranked passages and show it.
answer = generate(query=query, chunks=reranked_chunks)
print(answer)

In [None]:
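# Alternative generate() backed by Google Gemini, kept commented out.
# Uncommenting this cell redefines generate() and replaces the DeepSeek version above.
# from dotenv import load_dotenv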
# import google.generativeai as genai
# import os

# load_dotenv()
# # 从环境变量中获取 API 密钥
# api_key = os.getenv('GEMINI_API_KEY')
# # 配置 API 密钥
# genai.configure(api_key=api_key)

# def generate(query: str, chunks: list[str]) -> str:
#     prompt = f"""你是一位知识助手，请根据用户的问题和下列片段生成准确的回答。
# 用户问题: {query}
# 相关片段:
# {"".join(chunks)}
# 请基于上述内容作答，不要编造信息。"""
#     print(f"{prompt}\n\n---\n")
#     # 创建模型实例
#     model = genai.GenerativeModel("gemini-2.5-flash")
#     # 生成内容
#     response = model.generate_content(prompt)
#     # 返回响应文本，如果没有则返回空字符串
#     return response.text if response.text else ""

# answer = generate(query, reranked_chunks)
# print(answer)