# Self-RAG：一种动态的 RAG 方法

在本笔记中，我实现了 Self-RAG，这是一个先进的 RAG 系统，它可以动态决定何时以及如何使用检索到的信息。与传统的 RAG 方法不同，Self-RAG 在整个检索和生成过程中引入了反思点，从而产生更高质量和更可靠的响应。

## Self-RAG 的关键组成部分

1.  **检索决策**：确定对于给定的查询是否需要进行检索。
2.  **文档检索**：在需要时获取可能相关的文档。
3.  **相关性评估**：评估每个检索到的文档的相关程度。
4.  **响应生成**：基于相关上下文创建响应。
5.  **支持度评估**：评估响应是否恰当地基于上下文。
6.  **实用性评估**：对生成的响应的整体有用性进行评分。

导入库

In [None]:
import pymupdf
import os
import numpy as np
import json
import openai
from tqdm import tqdm
import re
from datetime import datetime

提取文本

In [None]:
def extract_text_from_pdf(pdf_path):
    """
    Extract the plain text of every page of a PDF file.

    Args:
        pdf_path (str): Path to the PDF file.

    Returns:
        str: Text of all pages concatenated in page order.
    """
    # Open the document in a context manager so the file handle is released
    # even if text extraction fails part-way through.
    with pymupdf.open(pdf_path) as document:
        # Join once instead of growing a string with += page by page.
        return "".join(page.get_text("text") for page in document)

# Path of the PDF to extract.
pdf_path = "data/AI_Information.pdf"


# Extract the text of the whole PDF.
extracted_text = extract_text_from_pdf(pdf_path)

# Print the first 500 characters as a quick preview of the extraction.
print(extracted_text[:500])

分块

In [None]:
def chunk_text(text, n, overlap):
    """
    Split text into fixed-size character chunks that overlap.

    Args:
        text (str): Input text.
        n (int): Size of each chunk in characters; must be positive.
        overlap (int): Number of characters shared by consecutive chunks;
            must satisfy 0 <= overlap < n.

    Returns:
        List[str]: Chunks in order of appearance. Trailing chunks may be
        shorter than n.

    Raises:
        ValueError: If n is not positive or overlap is outside [0, n).
    """
    if n <= 0:
        raise ValueError(f"chunk size n must be positive, got {n}")
    # overlap >= n would make the step zero or negative (range() then fails
    # or yields nothing); a negative overlap would skip text between chunks.
    if not 0 <= overlap < n:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < n, got overlap={overlap}, n={n}"
        )

    step = n - overlap
    return [text[start:start + n] for start in range(0, len(text), step)]

配置client

In [None]:
# OpenAI-compatible client pointed at the DashScope (Bailian) endpoint.
client = openai.OpenAI(
    api_key=os.getenv("DASHSCOPE_API_KEY"),  # Read from the environment; replace with your API key here if the variable is not configured
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"  # Base URL of the Bailian service
)

简易向量库

In [None]:
class SimpleVectorStore:
    """
    Minimal in-memory vector store with cosine-similarity search.

    Texts, embedding vectors and metadata are kept in three parallel lists;
    position i in each list describes the same item.
    """
    def __init__(self):
        self.vectors = []   # embedding vectors as NumPy arrays
        self.texts = []     # original text of each item
        self.metadata = []  # metadata dict of each item

    def add_item(self, text, embedding, metadata=None):
        """
        Add one item to the store.

        Args:
            text (str): Text content.
            embedding (List[float]): Embedding vector of the text.
            metadata (Dict, optional): Metadata attached to the text; an empty
                dict is stored when omitted.
        """
        self.vectors.append(np.array(embedding))
        self.texts.append(text)
        self.metadata.append(metadata or {})

    def similarity_search(self, query_embedding, k=5):
        """
        Return the k stored items most similar to the query embedding.

        Similarity is cosine similarity. A zero-length vector (stored or
        query) has no defined direction, so its similarity is reported as 0.0
        instead of NaN.

        Args:
            query_embedding (List[float]): Embedding vector of the query.
            k (int, optional): Maximum number of results to return.

        Returns:
            List[Dict]: Up to k dicts with keys "text", "metadata" and
            "similarity", sorted by descending similarity. Ties keep
            insertion order.
        """
        if not self.vectors or k <= 0:
            return []

        query_vector = np.asarray(query_embedding, dtype=float)
        matrix = np.stack(self.vectors).astype(float)

        # Compute all dot products and norm products at once; the query norm
        # is evaluated a single time instead of once per stored vector.
        dots = matrix @ query_vector
        norm_products = np.linalg.norm(matrix, axis=1) * np.linalg.norm(query_vector)

        # Divide only where the denominator is non-zero; other entries keep
        # the 0.0 they were initialised with.
        similarities = np.divide(
            dots,
            norm_products,
            out=np.zeros_like(dots),
            where=norm_products > 0,
        )

        # A stable sort on the negated scores gives descending order while
        # preserving insertion order among equal scores.
        ranked = np.argsort(-similarities, kind="stable")[:k]

        return [
            {
                "text": self.texts[idx],
                "metadata": self.metadata[idx],
                "similarity": float(similarities[idx]),
            }
            for idx in ranked
        ]

创建向量

In [None]:
def create_embeddings_in_batches(text_chunks, model="text-embedding-v3", batch_size_limit=10):
    """
    Embed a list of texts through the embeddings endpoint of the module-level
    client, sending at most `batch_size_limit` texts per request.

    Args:
        text_chunks (List[str]): Texts to embed. A non-list value is wrapped
            into a one-element list.
        model (str): Name of the embedding model.
        batch_size_limit (int): Maximum number of inputs per request.

    Returns:
        List[List[float]]: One embedding per input, in input order. An empty
        list when `text_chunks` is empty.

    Raises:
        Exception: Any error raised by the API call is reported on stdout and
        then re-raised.
    """
    # Nothing to embed: return before touching the API.
    if not text_chunks:
        return []

    # Accept a bare value by treating it as a single-item list.
    items = text_chunks if isinstance(text_chunks, list) else [text_chunks]

    collected = []
    for start in range(0, len(items), batch_size_limit):
        current = items[start:start + batch_size_limit]
        try:
            reply = client.embeddings.create(
                input=current,
                model=model,
                encoding_format="float"
            )
            # Append this batch's vectors in the order the API returned them.
            collected.extend([entry.embedding for entry in reply.data])
        except Exception as e:
            print(f"Error processing batch starting with chunk: '{current[0][:50]}...'")
            print(f"API Error: {e}")
            raise e

    return collected

def create_embeddings(text, model="text-embedding-v3"):
    """
    Embed a single text string.

    Args:
        text (str): Text to embed.
        model (str): Name of the embedding model.

    Returns:
        List[float]: Embedding of the first item in the API response.
    """
    reply = client.embeddings.create(input=text, model=model)
    first_item = reply.data[0]
    return first_item.embedding

文本处理流程

In [None]:
def process_document(pdf_path, chunk_size=1000, chunk_overlap=200, *, return_store=False):
    """
    Run the document-ingestion pipeline for a PDF.

    Steps:
    1. Extract text from the PDF.
    2. Split the text into overlapping chunks.
    3. Create an embedding for every chunk.
    4. Store chunks, embeddings and metadata in a SimpleVectorStore.

    Args:
        pdf_path (str): Path to the PDF file.
        chunk_size (int): Chunk size in characters.
        chunk_overlap (int): Overlap between consecutive chunks in characters.
        return_store (bool): When True, also return the populated vector
            store, which is what `self_rag` expects as `vector_store`. The
            default keeps the previous return value.

    Returns:
        List[str]: The text chunks, when `return_store` is False.
        Tuple[List[str], SimpleVectorStore]: The chunks and the populated
        store, when `return_store` is True.

    Raises:
        ValueError: If the number of embeddings differs from the number of
        chunks.
    """
    print("Extracting text from PDF...")
    extracted_text = extract_text_from_pdf(pdf_path)

    print("Chunking text...")
    chunks = chunk_text(extracted_text, chunk_size, chunk_overlap)
    print(f"Created {len(chunks)} text chunks")

    print("Creating embeddings for chunks...")
    chunk_embeddings = create_embeddings_in_batches(chunks)

    # zip() below would silently drop the unmatched tail on a mismatch and the
    # final message would then report a wrong count.
    if len(chunk_embeddings) != len(chunks):
        raise ValueError(
            f"Expected {len(chunks)} embeddings but received {len(chunk_embeddings)}"
        )

    store = SimpleVectorStore()

    for i, (chunk, embedding) in enumerate(zip(chunks, chunk_embeddings)):
        store.add_item(
            text=chunk,
            embedding=embedding,
            metadata={
                "index": i,
                "source": pdf_path,
                "relevance_score": 1.0,
                "feedback_count": 0,
            },
        )

    print(f"Added {len(chunks)} chunks to the vector store")

    # Without this opt-in the store built above is unreachable for callers.
    if return_store:
        return chunks, store
    return chunks

Self-RAG 组件

In [None]:
def determine_if_retrieval_needed(query):
    """
    Ask the chat model whether answering the query requires retrieval.

    Args:
        query (str): User query.

    Returns:
        bool: True when the model's answer indicates that retrieval is needed.
    """
    # The answer vocabulary is "yes"/"no" throughout the prompt so that it
    # matches what the parser below looks for.
    system_prompt = """您是一个AI助手，负责确定回答查询是否需要检索。
        对于事实问题、特定信息请求或有关事件、人员或概念的问题，请回答“yes”。
        对于观点、假设情景或常识性的简单查询，回答“no”。
        只回答“yes”或“no”。"""

    user_prompt = f"Query: {query}\n\nIs retrieval necessary to answer this query accurately?"

    # NOTE(review): the module-level client targets the DashScope endpoint, but
    # this model id does not look like a DashScope model name — confirm it is
    # served there.
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )

    answer = response.choices[0].message.content.strip().lower()

    # The prompt is written in Chinese, so also accept an answer that starts
    # with "是" (optionally quoted). Only the start is checked because "是"
    # also occurs inside negative answers such as "不是".
    return "yes" in answer or answer.lstrip("\"'“‘「 ").startswith("是")

评估相关性

In [None]:
def evaluate_relevance(query, context):
    """
    Evaluate whether a context is relevant to the query.

    The model's free-form answer is normalised, so callers can compare the
    result with an exact string.

    Args:
        query (str): User query.
        context (str): Context text.

    Returns:
        str: Exactly 'relevant' or 'irrelevant'. Answers that cannot be
        interpreted are reported as 'irrelevant'.
    """
    # System prompt to instruct the AI on how to determine document relevance
    system_prompt = """You are an AI assistant that determines if a document is relevant to a query.
    Consider whether the document contains information that would be helpful in answering the query.
    Answer with ONLY "Relevant" or "Irrelevant"."""

    # Truncate context if it is too long to avoid exceeding token limits
    max_context_length = 2000
    if len(context) > max_context_length:
        context = context[:max_context_length] + "... [truncated]"

    # User prompt containing the query and the document content
    user_prompt = f"""Query: {query}
    Document content:
    {context}

    Is this document relevant to the query? Answer with ONLY "Relevant" or "Irrelevant".
    """

    # NOTE(review): the module-level client targets the DashScope endpoint, but
    # this model id does not look like a DashScope model name — confirm it is
    # served there.
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )

    answer = response.choices[0].message.content.strip().lower()

    # "irrelevant" contains "relevant", so the negative forms are ruled out
    # first. This also tolerates trailing punctuation or extra words such as
    # "Relevant." or "The document is relevant".
    is_relevant = (
        "relevant" in answer
        and "irrelevant" not in answer
        and "not relevant" not in answer
    )
    return "relevant" if is_relevant else "irrelevant"

In [None]:
def assess_support(response, context):
    """
    Assess how well a response is supported by the context.

    The model's free-form answer is normalised, so callers can look the result
    up by exact string.

    Args:
        response (str): Generated response.
        context (str): Context text.

    Returns:
        str: Exactly 'fully supported', 'partially supported', or
        'no support'. Answers that cannot be interpreted are reported as
        'no support'.
    """
    # System prompt to instruct the AI on how to evaluate support
    system_prompt = """You are an AI assistant that determines if a response is supported by the given context.
    Evaluate if the facts, claims, and information in the response are backed by the context.
    Answer with ONLY one of these three options:
    - "Fully supported": All information in the response is directly supported by the context.
    - "Partially supported": Some information in the response is supported by the context, but some is not.
    - "No support": The response contains significant information not found in or contradicting the context.
    """

    # Truncate context if it is too long to avoid exceeding token limits
    max_context_length = 2000
    if len(context) > max_context_length:
        context = context[:max_context_length] + "... [truncated]"

    # User prompt containing the context and the response to be evaluated
    user_prompt = f"""Context:
    {context}

    Response:
    {response}

    How well is this response supported by the context? Answer with ONLY "Fully supported", "Partially supported", or "No support".
    """

    # The API result gets its own name instead of rebinding the `response`
    # parameter.
    # NOTE(review): the module-level client targets the DashScope endpoint, but
    # this model id does not look like a DashScope model name — confirm it is
    # served there.
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0
    )

    answer = completion.choices[0].message.content.strip().lower()

    # Map the answer onto the three canonical labels, tolerating quotes,
    # trailing punctuation and extra words around the label.
    if "partial" in answer:
        return "partially supported"
    if "fully" in answer and "not fully" not in answer:
        return "fully supported"
    return "no support"

In [None]:
def rate_utility(query, response):
    """
    Ask the chat model to score how useful a response is for the query.

    Args:
        query (str): User query.
        response (str): Generated response.

    Returns:
        int: The first digit between 1 and 5 found in the model's answer, or
        3 when the answer contains no such digit.
    """
    # Instructions describing the 1-5 rating scale
    system_prompt = """You are an AI assistant that rates the utility of a response to a query.
    Consider how well the response answers the query, its completeness, correctness, and helpfulness.
    Rate the utility on a scale from 1 to 5, where:
    - 1: Not useful at all
    - 2: Slightly useful
    - 3: Moderately useful
    - 4: Very useful
    - 5: Exceptionally useful
    Answer with ONLY a single number from 1 to 5."""

    # The query and the response that is being rated
    user_prompt = f"""Query: {query}
    Response:
    {response}

    Rate the utility of this response on a scale from 1 to 5:"""

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=conversation,
        temperature=0
    )

    raw_rating = completion.choices[0].message.content.strip()

    # Take the first digit in the range 1-5; use the middle of the scale when
    # the answer contains none.
    digit = re.search(r'[1-5]', raw_rating)
    return int(digit.group()) if digit else 3

In [None]:
def generate_response(query, context=None):
    """
    Produce an answer to the query, optionally grounded in a context.

    Args:
        query (str): User query.
        context (str, optional): Context text. When empty or None the model
            is asked to answer without any context.

    Returns:
        str: The model's answer with surrounding whitespace removed.
    """
    system_prompt = """You are a helpful AI assistant. Provide a clear, accurate, and informative response to the query."""

    # The prompts are spelled out line by line; the leading spaces inside the
    # literals are part of the text sent to the model.
    if not context:
        user_prompt = (
            f"Query: {query}\n"
            "        \n"
            "        Please answer the query to the best of your ability."
        )
    else:
        user_prompt = (
            "Context:\n"
            f"        {context}\n"
            "\n"
            f"        Query: {query}\n"
            "\n"
            "        Please answer the query based on the provided context.\n"
            "        "
        )

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]
    completion = client.chat.completions.create(
        model="meta-llama/Llama-3.2-3B-Instruct",
        messages=conversation,
        temperature=0.2
    )

    reply_text = completion.choices[0].message.content
    return reply_text.strip()

In [None]:
def _is_relevant_label(label):
    """Return True when a relevance verdict means 'relevant', ignoring case, punctuation and extra words."""
    text = str(label).strip().lower()
    # "irrelevant" contains "relevant", so the negative forms are excluded.
    return (
        "relevant" in text
        and "irrelevant" not in text
        and "not relevant" not in text
    )


def _support_score(label):
    """Map a support verdict to its score: fully -> 3, partially -> 1, anything else -> 0."""
    text = str(label).strip().lower()
    if "partial" in text:
        return 1
    if "fully" in text and "not fully" not in text:
        return 3
    return 0


def self_rag(query, vector_store, top_k=3):
    """
    Run the complete Self-RAG pipeline for one query.

    The pipeline decides whether retrieval is needed, retrieves and filters
    documents, generates one candidate answer per relevant document, scores
    each candidate by support and utility, and keeps the best one. When no
    relevant document is found, the answer is generated without context.

    Args:
        query (str): User query.
        vector_store (SimpleVectorStore): Vector store containing document chunks.
        top_k (int): Number of documents to retrieve initially.

    Returns:
        dict: Keys "query", "response" (the selected answer) and "metrics"
        (retrieval decision, document counts, per-candidate ratings,
        "best_score" and "used_retrieval").
    """
    print(f"\n=== Starting Self-RAG for query: {query} ===\n")

    # Step 1: Determine if retrieval is necessary
    print("Step 1: Determining if retrieval is necessary...")
    retrieval_needed = determine_if_retrieval_needed(query)
    print(f"Retrieval needed: {retrieval_needed}")

    # Metrics tracked across the whole run
    metrics = {
        "retrieval_needed": retrieval_needed,
        "documents_retrieved": 0,
        "relevant_documents": 0,
        "response_support_ratings": [],
        "utility_ratings": []
    }

    best_response = None
    best_score = -1

    if retrieval_needed:
        # Step 2: Retrieve documents
        print("\nStep 2: Retrieving relevant documents...")
        query_embedding = create_embeddings(query)
        results = vector_store.similarity_search(query_embedding, k=top_k)
        metrics["documents_retrieved"] = len(results)
        print(f"Retrieved {len(results)} documents")

        # Step 3: Evaluate relevance of each document
        print("\nStep 3: Evaluating document relevance...")
        relevant_contexts = []

        for i, result in enumerate(results):
            context = result["text"]
            relevance = evaluate_relevance(query, context)
            print(f"Document {i+1} relevance: {relevance}")

            # The verdict is free-form model text; an exact comparison with
            # "relevant" would reject answers such as "relevant.".
            if _is_relevant_label(relevance):
                relevant_contexts.append(context)

        metrics["relevant_documents"] = len(relevant_contexts)
        print(f"Found {len(relevant_contexts)} relevant documents")

        if relevant_contexts:
            # Step 4: Generate and score one candidate per relevant context
            print("\nStep 4: Processing relevant contexts...")
            for i, context in enumerate(relevant_contexts):
                print(f"\nProcessing context {i+1}/{len(relevant_contexts)}...")

                print("Generating response...")
                response = generate_response(query, context)

                print("Assessing support...")
                support_rating = assess_support(response, context)
                print(f"Support rating: {support_rating}")
                metrics["response_support_ratings"].append(support_rating)

                print("Rating utility...")
                utility_rating = rate_utility(query, response)
                print(f"Utility rating: {utility_rating}/5")
                metrics["utility_ratings"].append(utility_rating)

                # Support dominates the score: one support level outweighs
                # the whole 1-5 utility range.
                overall_score = _support_score(support_rating) * 5 + utility_rating
                print(f"Overall score: {overall_score}")

                if overall_score > best_score:
                    best_response = response
                    best_score = overall_score
                    print("New best response found!")

        # NOTE: rate_utility returns a value between 1 and 5, so every scored
        # candidate has overall_score >= 1; best_score <= 0 therefore only
        # holds when no candidate was scored at all.
        if not relevant_contexts or best_score <= 0:
            print("\nNo suitable context found or poor responses, generating without retrieval...")
            best_response = generate_response(query)
    else:
        # No retrieval needed, generate directly
        print("\nNo retrieval needed, generating response directly...")
        best_response = generate_response(query)

    # Final metrics
    metrics["best_score"] = best_score
    metrics["used_retrieval"] = retrieval_needed and best_score > 0

    print("\n=== Self-RAG Completed ===")

    return {
        "query": query,
        "response": best_response,
        "metrics": metrics
    }