In [None]:
# Kernel smoke test: prints a greeting to confirm that cells execute.
print("Hello, world!")

# Running cells with the ".venv (Python 3.12.4)" kernel requires the ipykernel package.
# Install "ipykernel" into that Python environment with:
#   <project-root>/.venv/bin/python -m pip install ipykernel -U --force-reinstall

# Install ipykernel (alternative, using uv):
#uv pip install ipykernel

# 检索增强生成（Retrieval Augmented Generation，简称RAG）


Claude（一种AI模型）在许多任务中表现出色，但在处理特定于您独特业务环境的查询时可能会遇到困难。这就是RAG变得非常有价值的原因。RAG使Claude能够利用您的内部知识库或客户支持文档，显著提高其回答特定领域问题的能力。越来越多的企业正在构建RAG应用程序，以改善客户支持、内部公司文档问答、财务和法律分析等方面的工作流程。

在本指南中，我们将演示如何使用Anthropic文档作为知识库来构建和优化RAG系统。我们将引导您完成以下步骤：

1. 使用内存向量数据库和Voyage AI的嵌入来设置基本的RAG系统。
2. 建立一个健全的评估套件。我们将超越基于"直觉"的评估，向您展示如何独立测量检索管道和端到端性能。
3. 实施先进技术来改进RAG，包括摘要索引和使用Claude进行重新排序。

通过一系列有针对性的改进，与基本RAG管道相比，我们在以下指标上取得了显著的性能提升（稍后我们会解释这些指标的含义）：

- 平均精确度：从0.43提升到0.46
- 平均召回率：从0.66提升到0.74
- 平均F1分数：从0.52提升到0.57
- 平均倒数排名（MRR）：从0.74提升到0.93
- 端到端准确率：从70%提升到83%


>**注意事项：**
>本指南中的评估旨在模拟生产环境中的评估系统，运行这些评估可能需要较长时间。另外需要注意：如果您完整运行评估，除非您的 Anthropic API 使用层级（usage tier）为 2 级或更高，否则可能会遇到速率限制。如果您想节省 token 用量，可以考虑跳过完整的端到端评估。

## 目录

1. 设置
2. 基本 RAG - 第 1 级
3. 建立评估系统
4. 第 2 级 - 摘要索引
5. 第 3 级 - 摘要索引和重新排序

## 设置

我们需要一些库,包括:

1. anthropic - 与 Claude 交互
2. voyageai - 生成高质量嵌入
3. pandas、numpy、matplotlib、seaborn 和 scikit-learn - 用于数据处理和可视化
4. python-dotenv - 从 `.env` 文件加载 API 密钥
5. tqdm - 在评估过程中显示进度条

您还需要从[Anthropic](https://www.anthropic.com/)和[Voyage AI](https://www.voyageai.com/)获取 API 密钥。




In [None]:
## setup
!uv pip install anthropic
!uv pip install voyageai
!uv pip install pandas
!uv pip install numpy
!uv pip install matplotlib
!uv pip install seaborn
!uv pip install -U scikit-learn
!uv pip install python-dotenv

In [None]:
import os
import dotenv

# Load variables from a local `.env` file (if present) into os.environ, where
# the Voyage and Anthropic clients later read them via os.getenv.
dotenv.load_dotenv()

# Fail fast with a clear message when a key is missing, instead of the opaque
# TypeError that `os.environ[name] = None` would raise.
for key_name in ("VOYAGE_API_KEY", "ANTHROPIC_API_KEY"):
    if not os.getenv(key_name):
        raise RuntimeError(f"{key_name} is not set. Add it to your environment or to a .env file.")
    # Never print the secret itself: notebook outputs are easily committed or shared.
    print(f"{key_name}: set")

In [None]:
import os

import anthropic

# Shared Anthropic client, used both to generate answers and to grade them.
# The key is passed explicitly for clarity; ANTHROPIC_API_KEY is also where
# the SDK looks by default, so this argument could be omitted.
client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))

# 初始化向量数据库类

在此示例中,我们正在使用内存向量数据库,但是对于生产应用程序,您可能需要使用托管解决方案。

In [None]:
import os
import pickle
import json
import numpy as np
import voyageai

class VectorDB:
    """Small in-memory vector store backed by Voyage AI embeddings.

    Chunk embeddings, their metadata and a cache of query embeddings are
    pickled to ``./data/<name>/vector_db.pkl`` so that later runs can skip
    re-embedding the corpus.
    """

    def __init__(self, name, api_key=None):
        """Create an empty store.

        Args:
            name: Store name; determines the on-disk path ``./data/<name>/vector_db.pkl``.
            api_key: Voyage AI API key. Defaults to the ``VOYAGE_API_KEY`` environment variable.
        """
        if api_key is None:
            api_key = os.getenv("VOYAGE_API_KEY")
        self.client = voyageai.Client(api_key=api_key)
        self.name = name
        self.embeddings = []   # one embedding per chunk, index-aligned with self.metadata
        self.metadata = []     # the chunk dicts passed to load_data
        self.query_cache = {}  # query text -> query embedding, to avoid re-embedding repeated queries
        self.db_path = f"./data/{name}/vector_db.pkl"

    def load_data(self, data):
        """Populate the store from ``data``, a list of chunk dicts with ``chunk_heading`` and ``text``.

        Does nothing if the store is already populated. If a pickle already
        exists at ``db_path`` it is loaded instead and ``data`` is ignored —
        delete that file after changing the corpus, or stale embeddings are reused.
        """
        if self.embeddings and self.metadata:
            print("Vector database is already loaded. Skipping data loading.")
            return
        if os.path.exists(self.db_path):
            print("Loading vector database from disk.")
            self.load_db()
            return

        texts = [f"Heading: {item['chunk_heading']}\n\n Chunk Text:{item['text']}" for item in data]
        self._embed_and_store(texts, data)
        self.save_db()
        print("Vector database loaded and saved.")

    def _embed_and_store(self, texts, data):
        """Embed ``texts`` with voyage-2 in batches of 128 and keep ``data`` as the aligned metadata."""
        batch_size = 128
        result = [
            self.client.embed(
                texts[i : i + batch_size],
                model="voyage-2"
            ).embeddings
            for i in range(0, len(texts), batch_size)
        ]
        self.embeddings = [embedding for batch in result for embedding in batch]
        self.metadata = data

    def search(self, query, k=5, similarity_threshold=0.75):
        """Return up to ``k`` chunks whose score against ``query`` is >= ``similarity_threshold``.

        The score is the dot product between the query embedding and each chunk
        embedding. This equals cosine similarity only if the embeddings are
        unit-normalised (presumably true for voyage-2 — TODO confirm).

        Returns:
            A list of ``{"metadata": chunk_dict, "similarity": score}``, highest score first.
        Raises:
            ValueError: if no data has been loaded.
        """
        # Check first so that an empty store does not cost an embedding API call.
        if not self.embeddings:
            raise ValueError("No data loaded in the vector database.")

        cache_miss = query not in self.query_cache
        if cache_miss:
            self.query_cache[query] = self.client.embed([query], model="voyage-2").embeddings[0]
        query_embedding = self.query_cache[query]

        similarities = np.dot(self.embeddings, query_embedding)
        top_examples = []
        for idx in np.argsort(similarities)[::-1]:
            # Scores are visited in descending order, so the first score below
            # the threshold ends the scan. Checking k before appending also
            # makes k <= 0 return nothing.
            if len(top_examples) >= k or similarities[idx] < similarity_threshold:
                break
            top_examples.append({
                "metadata": self.metadata[idx],
                "similarity": similarities[idx],
            })

        # Persist only when the query cache actually changed, instead of
        # rewriting the whole pickle on every search.
        if cache_miss:
            self.save_db()
        return top_examples

    def save_db(self):
        """Pickle the embeddings, metadata and JSON-encoded query cache to ``db_path``."""
        data = {
            "embeddings": self.embeddings,
            "metadata": self.metadata,
            "query_cache": json.dumps(self.query_cache),
        }
        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
        with open(self.db_path, "wb") as file:
            pickle.dump(data, file)

    def load_db(self):
        """Restore the state written by ``save_db``.

        Raises:
            ValueError: if ``db_path`` does not exist.
        """
        if not os.path.exists(self.db_path):
            raise ValueError("Vector database file not found. Use load_data to create a new database.")
        # SECURITY: pickle.load can execute arbitrary code; only load files this notebook wrote.
        with open(self.db_path, "rb") as file:
            data = pickle.load(file)
        self.embeddings = data["embeddings"]
        self.metadata = data["metadata"]
        self.query_cache = json.loads(data["query_cache"])

## 基本 RAG - 第 1 级

要开始,我们将使用简单的方法建立一个基本的 RAG 管道。这在业内有时被称为"朴素 RAG"。一个基本的 RAG 管道包括以下 3 个步骤:

1. 按标题分块文档 - 仅包含每个子标题下的内容
2. 嵌入每个文档
3. 使用余弦相似度来检索文档以回答查询



In [None]:
import json
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from tqdm import tqdm
import logging
from typing import Callable, List, Dict, Any, Tuple, Set

# Load the evaluation dataset: records with 'question', 'correct_chunks' and 'correct_answer'.
# An explicit utf-8 encoding avoids depending on the platform's default codec.
with open('evaluation/docs_evaluation_dataset.json', 'r', encoding='utf-8') as f:
    eval_data = json.load(f)

# Load the Anthropic documentation chunks (each with 'chunk_heading' and 'text') that form the knowledge base.
with open('data/anthropic_docs.json', 'r', encoding='utf-8') as f:
    anthropic_docs = json.load(f)

# Build the vector index, or reuse the cached one on disk if it already exists.
db = VectorDB("anthropic_docs")
db.load_data(anthropic_docs)

def retrieve_base(query, db, k=3):
    """Level-1 ("naive") retrieval: fetch up to ``k`` chunks for ``query``.

    Args:
        query: The user question.
        db: Vector store exposing ``search(query, k=...)``.
        k: Maximum number of chunks to retrieve (defaults to 3, as in the original pipeline).

    Returns:
        ``(results, context)``. ``results`` is the list returned by ``db.search``.
        ``context`` joins each chunk's ``text``, each wrapped in newlines, for
        insertion into the prompt.
    """
    results = db.search(query, k=k)
    context = "".join(f"\n{result['metadata']['text']}\n" for result in results)
    return results, context

def answer_query_base(query, db):
    """Answer ``query`` with Claude 3 Haiku, grounded on the chunks from ``retrieve_base``.

    Returns the text of the first content block in the model's reply.
    """
    _, context = retrieve_base(query, db)
    prompt = f"""
    You have been tasked with helping us to answer the following query: 
    <query>
    {query}
    </query>
    You have access to the following documents which are meant to provide context as you answer the query:
    <documents>
    {context}
    </documents>
    Please remain faithful to the underlying context, and only deviate from it if you are 100% sure that you know the answer already. 
    Answer the question now, and avoid providing preamble such as 'Here is the answer', etc
    """
    # Deterministic (temperature 0) single-turn request.
    request = {
        "model": "claude-3-haiku-20240307",
        "max_tokens": 2500,
        "temperature": 0,
        "messages": [{"role": "user", "content": prompt}],
    }
    reply = client.messages.create(**request)
    first_block = reply.content[0]
    return first_block.text

## 评估设置

在评估基于检索增强的生成模型(RAG)应用程序时,分别评估检索系统和端到端系统的性能非常关键。

我们合成生成了一个由 100 个样本组成的评估数据集,其中包括以下内容:

- 一个问题
- 与此问题相关的我们文档中的部分内容。这就是我们期望检索系统在被问到这个问题时检索到的内容。
- 正确答案。

这是一个较有挑战性的数据集。我们的一些问题需要在多个片段间进行综合分析才能得到正确答案,因此我们的系统需要同时加载多个片段。你可以通过打开 evaluation/docs_evaluation_dataset.json 来查看数据集。

运行下一个单元格以预览数据集


In [None]:
#previewing our eval dataset
import json

def preview_json(file_path, num_items=3):
    """Print the first ``num_items`` entries of a JSON list or object, plus its total size.

    Args:
        file_path: Path to a UTF-8 encoded JSON file.
        num_items: Number of leading list items / dict entries to show.

    This is only a preview helper, so errors (missing file, invalid JSON or
    anything else) are reported as printed messages rather than raised.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        if isinstance(data, list):
            preview_data = data[:num_items]
        elif isinstance(data, dict):
            preview_data = dict(list(data.items())[:num_items])
        else:
            print(f"Unexpected data type: {type(data)}. Cannot preview.")
            return

        # Report how many items are actually shown; there may be fewer than num_items.
        print(f"Preview of the first {len(preview_data)} items from {file_path}:")
        print(json.dumps(preview_data, indent=2))
        print(f"\nTotal number of items: {len(data)}")

    except FileNotFoundError:
        print(f"File not found: {file_path}")
    except json.JSONDecodeError:
        print(f"Invalid JSON in file: {file_path}")
    except Exception as e:
        print(f"An error occurred: {e}")


preview_json('evaluation/docs_evaluation_dataset.json')

## 指标定义（Metric Definitions）

我们将根据 5 个关键指标来评估我们的系统:精确度、召回率、F1 分数、平均倒数排名(MRR)和端到端准确性。

### 检索指标:

#### 精确率（Precision）

精确率表示检索到的块中真正相关的比例。它回答了这个问题："在我们检索到的块中,有多少是正确的?"


关键要点:

- 高精度表示一个高效的系统,很少出现误报。
- 低精度表示正在检索到许多无关的块。
- 我们的系统每次查询最多检索 3 个片段（k=3）；当与问题相关的片段少于 3 个时，这会拉低精确率。

#### 召回率（Recall）

召回率衡量了我们检索系统的完整性。它回答了这个问题:"在所有正确存在的块中,我们成功检索了多少个?"

关键要点:

- 高召回率意味着对必要信息的全面覆盖。
- 低召回率表明遗漏了重要部分。
- 高召回率是确保 LLM 能够获取所有所需信息的关键。

#### F1 分数

F1 分数提供了精确度和召回率之间的平衡度量。当需要一个单一的指标来评估系统性能时,尤其是在不平衡的类别分布情况下,它特别有用。

关键要点:

- F1 得分范围从 0 到 1,其中 1 代表完美的精确度和召回率。
- 它是精确度和召回率的调和平均值，趋向于两个值中较低的那个。
- 在假阳性和假阴性都很重要的场景中很有用。

解释 F1 分数:

- F1 分数 1.0 表示完美的精确度和召回率。
- F1 分数 0.0 表示最差的表现。
- 总的来说,F1 分数越高,整体表现越好。

### 平衡精确率、召回率和 F1 得分

- 精确度和召回率之间通常存在权衡。
- 我们的系统每次查询都会检索多个片段，这种做法更偏重召回率而非精确率。
- 最优平衡取决于具体用例。
- 在许多 RAG 系统中,高召回率通常是优先考虑的,因为LLMs可以在生成过程中过滤掉不太相关的信息。

## 平均倒数排名（MRR）@k

MRR 衡量我们的系统对相关信息的排序效果。它帮助我们了解：用户从检索结果的顶部开始浏览时，能多快找到所需的内容。

关键要点:

- MRR 的范围从 0 到 1,其中 1 是完美的(正确答案总是第一位的)。
- 它只考虑每个查询中第一个正确结果的排名。
- 较高的 MRR 表示相关信息的排序更好。

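计算公式为：

$$\mathrm{MRR} = \frac{1}{|Q|} \sum_{i=1}^{|Q|} \frac{1}{\mathrm{rank}_i}$$

其中：

- $|Q|$ 是查询的总数
- $\mathrm{rank}_i$ 是第 $i$ 个查询的第一个相关项目的位置（若没有检索到相关项目，则该查询的得分记为 0）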

## 端到端指标:

###  端到端准确性

我们使用 LLM 作为评判者（Claude 3.5 Sonnet），根据问题和标准答案来评估生成的答案是否正确。

这个指标评估了从检索到答案生成的整个管道。

## 定义我们的度量计算函数

In [None]:
def calculate_mrr(retrieved_links: List[str], correct_links: Set[str]) -> float:
    """Reciprocal rank (1-based) of the first retrieved link found in ``correct_links``; 0 if none is."""
    first_hit_rank = next(
        (rank for rank, link in enumerate(retrieved_links, start=1) if link in correct_links),
        None,
    )
    return 0 if first_hit_rank is None else 1 / first_hit_rank

def evaluate_retrieval(retrieval_function: Callable, evaluation_data: List[Dict[str, Any]], db: Any) -> Tuple[float, float, float, float, List[float], List[float], List[float]]:
    """Score a retrieval function against the labelled chunks in ``evaluation_data``.

    For each item, ``retrieval_function(item['question'], db)`` must return
    ``(chunks, context)``. Each chunk's ``metadata['chunk_link']`` (falling back
    to ``metadata['url']``) is compared against ``item['correct_chunks']``.

    If a retrieval call raises, the error is logged and the item is scored as
    an empty retrieval (precision = recall = MRR = 0). This keeps the per-item
    lists index-aligned with ``evaluation_data``, which callers build
    per-question tables from.

    Returns:
        ``(avg_precision, avg_recall, avg_mrr, f1, precisions, recalls, mrrs)``.
        ``f1`` is the harmonic mean of the *averaged* precision and recall,
        not the mean of per-item F1 scores.
    """
    precisions: List[float] = []
    recalls: List[float] = []
    mrrs: List[float] = []

    for i, item in enumerate(tqdm(evaluation_data, desc="Evaluating Retrieval")):
        try:
            retrieved_chunks, _ = retrieval_function(item['question'], db)
            retrieved_links = [chunk['metadata'].get('chunk_link', chunk['metadata'].get('url', '')) for chunk in retrieved_chunks]
        except Exception as e:
            logging.error(f"Error in retrieval function: {e}")
            # Score as "nothing retrieved" rather than skipping the item, so the
            # per-item lists stay aligned with evaluation_data.
            retrieved_links = []

        correct_links = set(item['correct_chunks'])

        true_positives = len(set(retrieved_links) & correct_links)
        precision = true_positives / len(retrieved_links) if retrieved_links else 0.0
        recall = true_positives / len(correct_links) if correct_links else 0.0
        mrr = calculate_mrr(retrieved_links, correct_links)

        precisions.append(precision)
        recalls.append(recall)
        mrrs.append(mrr)

        if (i + 1) % 10 == 0:
            print(f"Processed {i + 1}/{len(evaluation_data)} items. Current Avg Precision: {sum(precisions) / len(precisions):.4f}, Avg Recall: {sum(recalls) / len(recalls):.4f}, Avg MRR: {sum(mrrs) / len(mrrs):.4f}")

    avg_precision = sum(precisions) / len(precisions) if precisions else 0
    avg_recall = sum(recalls) / len(recalls) if recalls else 0
    avg_mrr = sum(mrrs) / len(mrrs) if mrrs else 0
    f1 = 2 * (avg_precision * avg_recall) / (avg_precision + avg_recall) if (avg_precision + avg_recall) > 0 else 0

    return avg_precision, avg_recall, avg_mrr, f1, precisions, recalls, mrrs

import re

# Extracts the judge's verdict even when the surrounding XML is malformed
# (for example, an unescaped "&" inside <explanation>).
_IS_CORRECT_PATTERN = re.compile(r"<is_correct>\s*(true|false)\s*</is_correct>", re.IGNORECASE)


def _parse_is_correct(response_text):
    """Return the boolean verdict from the judge's reply text.

    First tries a strict XML parse and reads the ``<is_correct>`` child of the
    root element. If parsing fails, or no such child is found, falls back to a
    regex over the raw text.

    Raises:
        ValueError: if neither method finds a true/false verdict.
    """
    try:
        node = ET.fromstring(response_text).find('is_correct')
        if node is not None and node.text is not None:
            return node.text.strip().lower() == 'true'
    except ET.ParseError as e:
        logging.error(f"XML parsing error: {e}")
    match = _IS_CORRECT_PATTERN.search(response_text)
    if match is None:
        raise ValueError("No <is_correct> verdict found in judge response")
    return match.group(1).lower() == 'true'


def evaluate_end_to_end(answer_query_function, db, eval_data):
    """Measure end-to-end answer accuracy, using Claude 3.5 Sonnet as the judge.

    For each item, ``answer_query_function(item['question'], db)`` produces an
    answer. The judge model then compares it with ``item['correct_answer']``.
    Any failure (answer generation, the judge API call, or an unreadable
    verdict) is logged and counted as incorrect.

    Returns:
        ``(accuracy, results)``: the fraction of items judged correct (0.0 for
        an empty ``eval_data``) and one bool per item, in input order.
    """
    correct_answers = 0
    results = []
    total_questions = len(eval_data)

    for i, item in enumerate(tqdm(eval_data, desc="Evaluating End-to-End")):
        query = item['question']
        correct_answer = item['correct_answer']
        is_correct = False

        # A failure for one question must not abort the whole evaluation run.
        try:
            generated_answer = answer_query_function(query, db)
            answer_ok = True
        except Exception as e:
            logging.error(f"Answer generation failed: {e}")
            generated_answer, answer_ok = "", False

        prompt = f"""
        You are an AI assistant tasked with evaluating the correctness of answers to questions about Anthropic's documentation.
        
        Question: {query}
        
        Correct Answer: {correct_answer}
        
        Generated Answer: {generated_answer}
        
        Is the Generated Answer correct based on the Correct Answer? You should pay attention to the substance of the answer, and ignore minute details that may differ. 
        
        Small differences or changes in wording don't matter. If the generated answer and correct answer are saying essentially the same thing then that generated answer should be marked correct. 
        
        However, if there is any critical piece of information which is missing from the generated answer in comparison to the correct answer, then we should mark this as incorrect. 
        
        Finally, if there are any direct contradictions between the correect answer and generated answer, we should deem the generated answer to be incorrect.
        
        Respond in the following XML format:
        <evaluation>
        <content>
        <explanation>Your explanation here</explanation>
        <is_correct>true/false</is_correct>
        </content>
        </evaluation>
        """

        if answer_ok:
            try:
                response = client.messages.create(
                    model="claude-3-5-sonnet-20240620",
                    max_tokens=1500,
                    messages=[
                        {"role": "user", "content": prompt},
                        # Prefill the opening tag and stop at the closing one, so the
                        # reply is expected to be just the <content>…</content> element.
                        {"role": "assistant", "content": "<evaluation>"}
                    ],
                    temperature=0,
                    stop_sequences=["</evaluation>"]
                )

                response_text = response.content[0].text
                logging.debug(response_text)
                is_correct = _parse_is_correct(response_text)

                logging.info(f"Question {i + 1}/{total_questions}: {query}")
                logging.info(f"Correct: {is_correct}")
                logging.info("---")
            except Exception as e:
                logging.error(f"Unexpected error: {e}")

        # Count and record the same value, so accuracy always agrees with results.
        if is_correct:
            correct_answers += 1
        results.append(is_correct)

        if (i + 1) % 10 == 0:
            current_accuracy = correct_answers / (i + 1)
            print(f"Processed {i + 1}/{total_questions} questions. Current Accuracy: {current_accuracy:.4f}")

    accuracy = correct_answers / total_questions if total_questions else 0.0
    return accuracy, results

## 用于绘制性能的辅助函数

In [None]:
import os
import json
import matplotlib.pyplot as plt
import seaborn as sns

def plot_performance(results_folder='evaluation/json_results', include_methods=None, colors=None):
    """Draw grouped bars of the five RAG metrics for every saved evaluation run.

    Args:
        results_folder: Directory of ``*.json`` result files. Each file is a dict
            with a ``name`` plus the metric keys written by the evaluation cells.
        include_methods: Optional collection of ``name`` values to plot; ``None`` plots all.
        colors: Bar colors, one per method in plotting order. Methods beyond
            ``len(colors)`` get extra colors from seaborn's "husl" palette.

    Runs are sorted by end-to-end accuracy. Files that are invalid or incomplete
    are skipped with a printed warning.
    """
    # Set default colors
    if colors is None:
        colors = ['skyblue', 'lightgreen', 'salmon']

    # Load JSON files; sorted so that ties in the final sort come out in a deterministic order.
    results = []
    for filename in sorted(os.listdir(results_folder)):
        if not filename.endswith('.json'):
            continue
        file_path = os.path.join(results_folder, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                print(f"Warning: {filename} is not a valid JSON file. Skipping.")
                continue
        if 'name' not in data:
            print(f"Warning: {filename} does not contain a 'name' field. Skipping.")
            continue
        if include_methods is None or data['name'] in include_methods:
            results.append(data)

    if not results:
        print("No JSON files found with matching 'name' fields.")
        return

    # Validate data: keep only runs that report every metric we plot.
    required_metrics = ["average_precision", "average_recall", "average_f1", "average_mrr", "end_to_end_accuracy"]
    valid_results = []
    for result in results:
        if all(metric in result for metric in required_metrics):
            valid_results.append(result)
        else:
            print(f"Warning: {result['name']} is missing some required metrics. Skipping.")
    results = valid_results

    if not results:
        print("No valid results remaining after validation.")
        return

    # Sort results based on end-to-end accuracy
    results.sort(key=lambda x: x['end_to_end_accuracy'])

    metrics = required_metrics
    num_methods = len(results)

    # Use the caller's colors first. Only request extra "husl" colors when there
    # are more methods than colors; a negative count is invalid for seaborn.
    color_palette = list(colors[:num_methods])
    missing_colors = num_methods - len(color_palette)
    if missing_colors > 0:
        color_palette += list(sns.color_palette("husl", missing_colors))

    # Set up the plot
    plt.figure(figsize=(14, 6))
    sns.set_style("whitegrid")

    x = range(len(metrics))
    width = 0.8 / num_methods

    # Plot one group of bars per metric, one bar per method, centred on each tick.
    for i, (result, color) in enumerate(zip(results, color_palette)):
        values = [result[metric] for metric in metrics]
        offset = (i - num_methods / 2 + 0.5) * width
        bars = plt.bar([xi + offset for xi in x], values, width, label=result['name'], color=color)

        # Add value labels on the bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2., height,
                     f'{height:.2f}', ha='center', va='bottom', fontsize=8)

    # Customize the plot
    plt.xlabel('Metrics', fontsize=12)
    plt.ylabel('Values', fontsize=12)
    plt.title('RAG Performance Metrics (Sorted by End-to-End Accuracy)', fontsize=16)
    plt.xticks(x, metrics, rotation=45, ha='right')
    plt.legend(title='Methods', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.ylim(0, 1)

    plt.tight_layout()
    plt.show()

## 评估基线（基本 RAG）

In [None]:
import json
import os

import pandas as pd

# Output paths for this baseline run. Keeping them in one place keeps the
# printed messages in sync with the files actually written.
BASE_RESULTS_CSV = 'evaluation/csvs/evaluation_results_detailed.csv'
BASE_RESULTS_JSON = 'evaluation/json_results/evaluation_results_one.json'

avg_precision, avg_recall, avg_mrr, f1, precisions, recalls, mrrs = evaluate_retrieval(retrieve_base, eval_data, db)
e2e_accuracy, e2e_results = evaluate_end_to_end(answer_query_base, db, eval_data)

# One row per evaluation question. Fail with a clear message if an evaluator skipped items.
assert len(precisions) == len(e2e_results) == len(eval_data), (
    f"Per-item results are misaligned: {len(precisions)} retrieval scores, "
    f"{len(e2e_results)} end-to-end results, {len(eval_data)} questions."
)
df = pd.DataFrame({
    'question': [item['question'] for item in eval_data],
    'retrieval_precision': precisions,
    'retrieval_recall': recalls,
    'retrieval_mrr': mrrs,
    'e2e_correct': e2e_results
})

# Save to CSV, creating the output directory on a fresh checkout.
os.makedirs(os.path.dirname(BASE_RESULTS_CSV), exist_ok=True)
df.to_csv(BASE_RESULTS_CSV, index=False)
print(f"Detailed results saved to {BASE_RESULTS_CSV}")

# Print the results
print(f"Average Precision: {avg_precision:.4f}")
print(f"Average Recall: {avg_recall:.4f}")
print(f"Average MRR: {avg_mrr:.4f}")
print(f"Average F1: {f1:.4f}")
print(f"End-to-End Accuracy: {e2e_accuracy:.4f}")

# Save the summary metrics; plot_performance reads these JSON files.
os.makedirs(os.path.dirname(BASE_RESULTS_JSON), exist_ok=True)
with open(BASE_RESULTS_JSON, 'w') as f:
    json.dump({
        "name": "Basic RAG",
        "average_precision": avg_precision,
        "average_recall": avg_recall,
        "average_f1": f1,
        "average_mrr": avg_mrr,
        "end_to_end_accuracy": e2e_accuracy
    }, f, indent=2)

print(f"Evaluation complete. Results saved to {BASE_RESULTS_JSON}, {BASE_RESULTS_CSV}")

In [None]:
# Visualize the metrics saved for the "Basic RAG" baseline run.
plot_performance(results_folder='evaluation/json_results', include_methods=['Basic RAG'], colors=['skyblue'])