# Relevant Segment Extraction (RSE) for Enhanced RAG

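Relevant Segment Extraction (RSE) is a technique for improving context quality in our RAG system. Instead of simply retrieving a set of isolated text chunks, it identifies and reconstructs contiguous text segments, giving the language model better context.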

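## Core Concept

Relevant chunks in a document tend to cluster together. By identifying these clusters and preserving their continuity, we can give the large language model (LLM) more coherent context to work with.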
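To make the idea concrete, here is a minimal sketch with made-up relevance scores (the numbers and the `top_k` value are illustrative assumptions, not output from this notebook). It shows how plain top-k retrieval can skip a moderately scored chunk that sits between two strong hits, while taking the contiguous span keeps it.

```python
# Hypothetical relevance scores for 8 consecutive chunks (illustrative values only).
scores = [0.10, 0.15, 0.20, 0.80, 0.45, 0.75, 0.12, 0.08]

# Plain top-k: pick the k highest-scoring chunks independently of their position.
top_k = 2
ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
top_indices = sorted(ranked[:top_k])

# Contiguous span: cover everything between the first and last top hit.
span = list(range(top_indices[0], top_indices[-1] + 1))

print("top-k chunks:", top_indices)
print("contiguous span:", span)
```

The span includes the chunk that lies between the two hits, which is the text a reader would need to follow the passage.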

Import the required libraries

In [81]:
import pymupdf
import os
import numpy as np
import json
import openai
from tqdm import tqdm
import re

Extract text from the PDF

In [82]:
def extract_text_from_pdf(pdf_path):
    """
    Extract the text from every page of a PDF file.

    Args:
        pdf_path (str): Path to the PDF file.

    Returns:
        str: The text extracted from the PDF.
    """
    # Open the PDF file
    mypdf = pymupdf.open(pdf_path)
    all_text = ""  # Accumulates the extracted text

    # Iterate over every page in the PDF
    for page_num in range(mypdf.page_count):
        page = mypdf[page_num]  # Get the page
        text = page.get_text("text")  # Extract plain text from the page
        all_text += text  # Append it to the running text

    return all_text  # Return the extracted text

pdf_path = "data/AI_Information.pdf"


extracted_text = extract_text_from_pdf(pdf_path)

print(extracted_text[:500])

Understanding Artificial Intelligence 
Chapter 1: Introduction to Artificial Intelligence 
Artificial intelligence (AI) refers to the ability of a digital computer or computer-controlled robot 
to perform tasks commonly associated with intelligent beings. The term is frequently applied to 
the project of developing systems endowed with the intellectual processes characteristic of 
humans, such as the ability to reason, discover meaning, generalize, or learn from past 
experience. Over the past f


Chunking

In [83]:
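def chunk_text(text, n, overlap):
    """
    Split text into chunks of n characters, with `overlap` characters shared
    between consecutive chunks.

    Args:
        text (str): Input text.
        n (int): Size of each chunk, in characters.
        overlap (int): Number of characters shared by adjacent chunks.

    Returns:
        List[str]: The list of text chunks.
    """
    if overlap >= n:
        # A non-positive step would make range() raise or silently yield no chunks
        raise ValueError("overlap must be smaller than n")

    chunks = []
    for i in range(0, len(text), n - overlap):
        chunks.append(text[i:i + n])

    return chunks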
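As a quick check of the slicing behavior, the sketch below repeats the same stepping logic on a ten-character string (the sample string and sizes are chosen purely for illustration). Note that the final chunk can be shorter than `n`.

```python
def chunk_text(text, n, overlap):
    # Same slicing logic as the notebook's chunk_text, repeated so this runs on its own.
    return [text[i:i + n] for i in range(0, len(text), n - overlap)]

sample = "abcdefghij"  # 10 characters, illustrative input

no_overlap = chunk_text(sample, 4, 0)    # step of 4 characters
with_overlap = chunk_text(sample, 4, 2)  # step of 2 characters

print(no_overlap)
print(with_overlap)
```

RSE uses `overlap=0` later in the notebook so that consecutive chunks can be stitched back into the original text without duplicated characters.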

Configure the client

In [84]:
client = openai.OpenAI(
    api_key=os.getenv("DASHSCOPE_API_KEY"),  # If the environment variable is not set, replace this with your API key
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"  # base_url of the Bailian (DashScope) service
)

A simple vector store

In [85]:
class SimpleVectorStore:
    """
    A simple in-memory vector store.
    """
    def __init__(self):
        
        self.vectors = []
        self.texts = []
        self.metadata = []
        self.documents = []
    def add_documents(self, documents, vectors=None, metadata=None):
        """
        Add documents to the vector store.
        
        Args:
            documents (List[str]): List of document chunks
            vectors (List[List[float]], optional): List of embedding vectors
            metadata (List[Dict], optional): List of metadata dictionaries
        """
        if vectors is None:
            vectors = [None] * len(documents)
        
        if metadata is None:
            metadata = [{} for _ in range(len(documents))]
        
        for doc, vec, meta in zip(documents, vectors, metadata):
            self.documents.append(doc)
            self.texts.append(doc)  # keep texts aligned with vectors for similarity_search()
            self.vectors.append(vec)
            self.metadata.append(meta)
    
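    def add_item(self, text, embedding, metadata=None):
        """
        Add a single item to the store.

        Args:
            text (str): The text content.
            embedding (List[float]): The embedding vector for the text.
            metadata (Dict, optional): Metadata associated with the text.
        """
        self.vectors.append(np.array(embedding))
        self.texts.append(text)
        self.documents.append(text)  # keep documents aligned with vectors for search()
        self.metadata.append(metadata or {})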
    
    def search(self, query_vector, top_k=5):
        """
        Search for most similar documents.
        
        Args:
            query_vector (List[float]): Query embedding vector
            top_k (int): Number of results to return
            
        Returns:
            List[Dict]: List of results with documents, scores, and metadata
        """
        if not self.vectors or not self.documents:
            return []
        
        # Convert query vector to numpy array
        query_array = np.array(query_vector)
        
        # Calculate similarities
        similarities = []

        for i, vector in enumerate(self.vectors):
            if vector is not None:

                similarity = np.dot(query_array, vector) / (
                    np.linalg.norm(query_array) * np.linalg.norm(vector)
                )
                similarities.append((i, similarity))
        
        # Sort by similarity (descending)

        similarities.sort(key=lambda x: x[1], reverse=True)
        
        # Get top-k results
        results = []
        for i, score in similarities[:top_k]:
            results.append({
                "document": self.documents[i],
                "score": float(score),
                "metadata": self.metadata[i]
            })
        
        return results

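    def similarity_search(self, query_embedding, k=5):
        """
        Find the texts most similar to the query embedding.

        Args:
            query_embedding (List[float]): The query embedding vector.
            k (int, optional): Number of most similar results to return.

        Returns:
            List[Dict]: The most similar texts with their metadata and similarity scores.
        """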
        if not self.vectors:
            return []
        

        query_vector = np.array(query_embedding)
        

        similarities = []

        for i, vector in enumerate(self.vectors):
            if vector is None:
                continue  # skip documents that were added without an embedding
            similarity = np.dot(query_vector, vector) / (np.linalg.norm(query_vector) * np.linalg.norm(vector))
            similarities.append((i, similarity))
        

        similarities.sort(key=lambda x: x[1], reverse=True)
        

        results = []
        for i in range(min(k, len(similarities))):
            idx, score = similarities[i]
            results.append({
                "text": self.texts[idx],
                "metadata": self.metadata[idx],
                "similarity": score
            })
        
        return results
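Both search methods rank chunks by cosine similarity. The following sketch computes the same quantity with the standard library only, on hand-picked two-dimensional vectors (the vectors are illustrative assumptions, not real embeddings), to show that the score depends on direction and not on vector length.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

query = [1.0, 0.0]
same_direction = cosine_similarity(query, [2.0, 0.0])  # longer vector, same direction
orthogonal = cosine_similarity(query, [0.0, 3.0])      # unrelated direction
opposite = cosine_similarity(query, [-1.0, 0.0])       # opposite direction

print(same_direction, orthogonal, opposite)
```

Because scores fall in the range from -1 to 1, a fixed penalty such as 0.2 can later be subtracted to separate relevant chunks from irrelevant ones.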

Generate embeddings

In [86]:
def create_embeddings_in_batches(text_chunks, model="text-embedding-v3", batch_size_limit=10):
    """
    Create embeddings for a list of texts through the OpenAI-compatible
    Embedding API, splitting the requests to respect the batch-size limit.

    Args:
        text_chunks (List[str]): The text strings to embed.
        model (str): The embedding model to use.
        batch_size_limit (int): Maximum batch size the API accepts. Set to 10 here, based on the API's error message.

    Returns:
        List[List[float]]: The embedding vectors for all texts.
    """
    all_embeddings = []
    if not text_chunks:
        return []

    if not isinstance(text_chunks, list):  # Ensure the input is a list
        text_chunks = [text_chunks]

    for i in range(0, len(text_chunks), batch_size_limit):
        batch = text_chunks[i:i + batch_size_limit]
        try:
            #print(f"Processing batch {i//batch_size_limit + 1}, size: {len(batch)}")
            response = client.embeddings.create(
                input=batch,
                model=model,
                encoding_format="float"
            )
            # Extract this batch's embedding vectors from the response
            batch_embeddings = [item.embedding for item in response.data]
            all_embeddings.extend(batch_embeddings)

        except Exception as e:
            print(f"Error processing batch starting with chunk: '{batch[0][:50]}...'")
            print(f"API Error: {e}")
            raise  # re-raise with the original traceback

    return all_embeddings

def create_embeddings(text, model="text-embedding-v3"):
    """
    Create the embedding for a single string.

    Args:
        text (str): The text string to embed.
        model (str): The embedding model to use.

    Returns:
        List[float]: The embedding vector for the text.
    """
    response = client.embeddings.create(
        model=model,
        input=text
    )

    return response.data[0].embedding
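The batching step can be looked at in isolation, without calling the API. The sketch below uses a hypothetical helper, `split_into_batches`, that is not part of the notebook, and splits 42 placeholder chunks with a limit of 10, mirroring the slicing used in `create_embeddings_in_batches`.

```python
def split_into_batches(items, batch_size):
    """Split a list into consecutive batches of at most batch_size items."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

# Placeholder chunks standing in for real document text.
chunks = [f"chunk {i}" for i in range(42)]

batches = split_into_batches(chunks, 10)
sizes = [len(batch) for batch in batches]

print(len(batches), sizes)
```

Since the batches are processed in order and the results are appended with `extend`, the returned embeddings stay aligned with the input chunks.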

Process the document

In [87]:
def process_document(pdf_path, chunk_size=800):
    
    print("Extracting text from document...")

    text = extract_text_from_pdf(pdf_path)
    
    print("Chunking text into non-overlapping segments...")

    chunks = chunk_text(text, chunk_size, 0)
    print(f"Created {len(chunks)} chunks")
    
    print("Generating embeddings for chunks...")

    chunk_embeddings = create_embeddings_in_batches(chunks)
    
    vector_store = SimpleVectorStore()
    
    metadata = [{"chunk_index": i, "source": pdf_path} for i in range(len(chunks))]
    vector_store.add_documents(chunks, chunk_embeddings, metadata)
    

    doc_info = {
        "chunks": chunks,
        "source": pdf_path,
    }
    
    return chunks, vector_store, doc_info

Core RSE methods: computing chunk values and finding the best segments

In [88]:
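def calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty=0.2):
    """
    Compute a value for each chunk: its relevance score minus a fixed penalty,
    indexed by the chunk's position in the document.

    Args:
        query (str): The query text.
        chunks (List[str]): The list of document chunks.
        vector_store (SimpleVectorStore): The vector store containing the chunks.
        irrelevant_chunk_penalty (float): Penalty subtracted from every chunk's score,
            so weakly relevant chunks end up with negative values.

    Returns:
        List[float]: The list of chunk values.
    """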
    query_embedding = create_embeddings(query)
    num_chunks = len(chunks)
    results = vector_store.search(query_embedding, num_chunks)
    
    relevance_scores = {result["metadata"]["chunk_index"]: result["score"] for result in results}
    
    chunk_values = []
    for i in range(num_chunks):
        score = relevance_scores.get(i, 0.0)
        value = score - irrelevant_chunk_penalty
        chunk_values.append(value)
    
    return chunk_values
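A small worked example may help show why the penalty matters. The scores below are hypothetical similarity values for three adjacent chunks. After subtracting the penalty, the weak middle chunk becomes negative, yet the span covering all three chunks is still worth more than the best single chunk.

```python
penalty = 0.2  # same default as irrelevant_chunk_penalty
scores = [0.75, 0.125, 0.5]  # hypothetical similarity scores for three adjacent chunks

# Chunk value = relevance score minus the penalty.
values = [score - penalty for score in scores]

best_single = max(values)  # value of the strongest chunk alone
whole_span = sum(values)   # value of the contiguous span covering all three

print(values)
print("best single chunk:", best_single)
print("whole span:", whole_span)
```

A larger penalty makes bridging weak chunks more expensive and therefore yields shorter segments; a smaller one yields longer segments.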

In [89]:
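def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """
    Find the best contiguous text segments with a greedy variant of the
    maximum-sum subarray algorithm.

    Args:
        chunk_values (List[float]): The value of each chunk.
        max_segment_length (int): Maximum length of a single segment, in chunks.
        total_max_length (int): Target cap on the total length of all segments. It is
            checked only before each new search, so the final total can exceed it.
        min_segment_value (float): Minimum value a segment must exceed to be considered.

    Returns:
        Tuple[List[Tuple[int, int]], List[float]]: The (start, end) index pairs of the
            best segments, sorted by start, and their scores in the order they were found.
    """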
    print("Finding optimal continuous text segments...")
    
    best_segments = []
    segment_scores = []
    total_included_chunks = 0
    
    while total_included_chunks < total_max_length:
        best_score = min_segment_value  
        best_segment = None

        for start in range(len(chunk_values)):

            if any(start >= s[0] and start < s[1] for s in best_segments):
                continue
                
            for length in range(1, min(max_segment_length, len(chunk_values) - start) + 1):
                end = start + length
                
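                # Stop extending once the candidate overlaps a chosen segment;
                # this also rejects candidates that fully enclose one.
                if any(start < s[1] and end > s[0] for s in best_segments):
                    break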
                
                segment_value = sum(chunk_values[start:end])
                
                if segment_value > best_score:
                    best_score = segment_value
                    best_segment = (start, end)
        
        if best_segment:
            best_segments.append(best_segment)
            segment_scores.append(best_score)
            total_included_chunks += best_segment[1] - best_segment[0]
            print(f"Found segment {best_segment} with score {best_score:.4f}")
        else:
            break
    
    best_segments = sorted(best_segments, key=lambda x: x[0])
    
    return best_segments, segment_scores
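The greedy search can be traced on a toy input. The sketch below is a simplified stand-in, not the notebook's function: `best_segment` is a hypothetical helper, the chunk values are invented, and the threshold is fixed at 0.0 instead of `min_segment_value`. It repeatedly picks the highest-sum span that does not overlap an earlier pick and stops when no span has a positive sum.

```python
def best_segment(values, max_len, taken):
    """Return ((start, end), score) for the best free span, or (None, 0.0) if none is positive."""
    best, best_score = None, 0.0
    for start in range(len(values)):
        for end in range(start + 1, min(start + max_len, len(values)) + 1):
            # Half-open intervals overlap when each starts before the other ends.
            if any(start < t_end and end > t_start for t_start, t_end in taken):
                break  # longer spans from this start would overlap as well
            score = sum(values[start:end])
            if score > best_score:
                best, best_score = (start, end), score
    return best, best_score

# Invented chunk values: two relevant regions separated by irrelevant chunks.
values = [-0.25, 0.5, 0.25, -0.5, -0.5, 0.5, -0.25]

segments = []
while True:
    segment, score = best_segment(values, max_len=4, taken=segments)
    if segment is None:
        break
    segments.append(segment)

segments.sort()
print(segments)
```

The two negative chunks in the middle cost more than bridging them would gain, so the search returns two separate segments instead of one long one.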

Reconstruct the segments and use them for RAG

In [90]:
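def reconstruct_segments(chunks, best_segments):
    """
    Rebuild text segments from chunk indices.

    Args:
        chunks (List[str]): All document chunks.
        best_segments (List[Tuple[int, int]]): The (start, end) index pairs of the segments.

    Returns:
        List[Dict]: One dict per segment with its "text" and "segment_range".
    """
    reconstructed_segments = []
    for start, end in best_segments:
        # Chunks are non-overlapping character slices, so join them without a
        # separator to avoid inserting spaces in the middle of words
        segment_text = "".join(chunks[start:end])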
        
        reconstructed_segments.append({
            "text": segment_text,
            "segment_range": (start, end),
        })
    
    return reconstructed_segments  

In [91]:
def format_segments_for_context(segments):
    """
    Format the segments into a context string for the LLM.

    Args:
        segments (List[Dict]): The list of segment dictionaries.

    Returns:
        str: The formatted context text.
    """
    context = []  
    
    for i, segment in enumerate(segments):
        
        segment_header = f"SEGMENT {i+1} (Chunks {segment['segment_range'][0]}-{segment['segment_range'][1]-1}):"
        context.append(segment_header)  
        context.append(segment['text'])  
        context.append("-" * 80)  
    
    return "\n\n".join(context)
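One detail worth noting is that `segment_range` is a half-open `(start, end)` pair, while the header shows an inclusive chunk range. The sketch below builds one header for a made-up segment to illustrate the conversion.

```python
# A made-up segment covering chunks 3, 4 and 5 (end index 6 is exclusive).
segment = {"text": "example text", "segment_range": (3, 6)}

start, end = segment["segment_range"]
header = f"SEGMENT 1 (Chunks {start}-{end - 1}):"

print(header)
```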

Generate the answer

In [None]:
def generate_response(query, context, model="qwen3-4b"):
    print("Generating response using relevant segments as context...")

    system_prompt = """You are a helpful assistant that answers questions based on the provided context.
    The context consists of document segments that have been retrieved as relevant to the user's query.
    Use the information from these segments to provide a comprehensive and accurate answer.
    If the context doesn't contain relevant information to answer the question, say so clearly."""
    
    user_prompt = f"""
        Context:{context}
        Question: {query}
        Please provide a helpful answer based on the context provided.
        """
    

    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=0,
        # enable_thinking is not a standard OpenAI parameter, so it goes through extra_body
        extra_body={"enable_thinking": False}
    )
    
    return response.choices[0].message.content

The complete RSE pipeline

In [93]:
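def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2):
    """
    The complete RAG pipeline with Relevant Segment Extraction.

    Args:
        pdf_path (str): Path to the document.
        query (str): The user query.
        chunk_size (int): Size of each chunk.
        irrelevant_chunk_penalty (float): Penalty applied to irrelevant chunks.

    Returns:
        Dict: The result containing the query, segments, and response.
    """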
    print("\n=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===")
    print(f"Query: {query}")
    
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)
    
    print("\nCalculating relevance scores and chunk values...")
    chunk_values = calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty)
    
    best_segments, scores = find_best_segments(
        chunk_values, 
        max_segment_length=20, 
        total_max_length=30, 
        min_segment_value=0.2
    )
    
    print("\nReconstructing text segments from chunks...")
    segments = reconstruct_segments(chunks, best_segments)
    
    context = format_segments_for_context(segments)
    
    response = generate_response(query, context)
    
    result = {
        "query": query,
        "segments": segments,
        "response": response
    }
    
    print("\n=== FINAL RESPONSE ===")
    print(response)
    
    return result

Comparison with standard retrieval

In [94]:
def standard_top_k_retrieval(pdf_path, query, k=10, chunk_size=800):

    print("\n=== STARTING STANDARD TOP-K RETRIEVAL ===")
    print(f"Query: {query}")
    
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)

    print("Creating query embedding and retrieving chunks...")
    query_embedding = create_embeddings(query)

    results = vector_store.search(query_embedding, top_k=k)
    retrieved_chunks = [result["document"] for result in results]

    context = "\n\n".join([
        f"CHUNK {i+1}:\n{chunk}" 
        for i, chunk in enumerate(retrieved_chunks)
    ])

    response = generate_response(query, context)

    result = {
        "query": query,
        "chunks": retrieved_chunks,
        "response": response
    }
    
    print("\n=== FINAL RESPONSE ===")
    print(response)
    
    return result

Evaluation

In [95]:
def evaluate_methods(pdf_path, query, reference_answer=None):

    print("\n========= EVALUATION =========\n")
    
    rse_result = rag_with_rse(pdf_path, query)

    standard_result = standard_top_k_retrieval(pdf_path, query)

    if reference_answer:
        print("\n=== COMPARING RESULTS ===")

        evaluation_prompt = f"""
            Query: {query}

            Reference Answer:
            {reference_answer}

            Response from Standard Retrieval:
            {standard_result["response"]}

            Response from Relevant Segment Extraction:
            {rse_result["response"]}

            Compare these two responses against the reference answer. Which one is:
            1. More accurate and comprehensive
            2. Better at addressing the user's query
            3. Less likely to include irrelevant information

            Explain your reasoning for each point.
        """
        
        print("Evaluating responses against reference answer...")

        evaluation = client.chat.completions.create(
            model="qwen-plus",
            messages=[
                {"role": "system", "content": "You are an objective evaluator of RAG system responses."},
                {"role": "user", "content": evaluation_prompt}
            ]
        )

        print("\n=== EVALUATION RESULTS ===")
        print(evaluation.choices[0].message.content)

    return {
        "rse_result": rse_result,
        "standard_result": standard_result
    }

In [96]:

with open('data/val.json') as f:
    data = json.load(f)

query = data[0]['question']

reference_answer = data[0]['ideal_answer']

pdf_path = "data/AI_Information.pdf"

results = evaluate_methods(pdf_path, query, reference_answer)




=== STARTING RAG WITH RELEVANT SEGMENT EXTRACTION ===
Query: What is 'Explainable AI' and why is it considered important?
Extracting text from document...
Chunking text into non-overlapping segments...
Created 42 chunks
Generating embeddings for chunks...

Calculating relevance scores and chunk values...
Finding optimal continuous text segments...
Found segment (22, 42) with score 6.7414
Found segment (0, 20) with score 6.3919

Reconstructing text segments from chunks...
Generating response using relevant segments as context...

=== FINAL RESPONSE ===
Explainable AI (XAI) refers to the development of AI systems that are more transparent and understandable. The goal of XAI is to provide insights into how AI models make decisions, which enhances trust and accountability. This is important because many AI systems, particularly deep learning models, are often considered "black boxes," meaning it is difficult to understand how they arrive at their decisions. 

By making AI systems more e