# RAG框架作业实现
## 集成私有Embedding、Reranker、LLM模型与Chroma向量数据库

## 1. 环境准备与依赖安装

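在开始之前，请确保已安装所有必要的依赖项，并通过环境变量（如 `EMBEDDING_API_KEY`、`EMBEDDING_BASE_URL`、`RERANKER_API_KEY`、`RERANKER_BASE_URL`、`LLM_API_KEY`、`LLM_BASE_URL` 等）配置各模型服务的地址与密钥。您可以使用以下命令创建虚拟环境并安装依赖：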

```bash
conda create -n rag_week02 python=3.9
conda activate rag_week02
pip install -r requirements.txt
```

In [None]:
!pip install -r requirements.txt

## 2. 核心模块定义

In [None]:
import os
import requests
from openai import OpenAI
from tqdm import tqdm
import chromadb

class CustomEmbedding:
    """Thin client for an OpenAI-compatible embeddings endpoint.

    ``get_embeddings`` is used directly and is also handed to Chroma as the
    collection's embedding function.
    """

    def __init__(self, api_key, base_url, model_name):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` and return one vector per input, in input order.

        Args:
            texts: Texts to embed. A single ``str`` is accepted as well and is
                treated as a one-element list.

        Returns:
            Embedding vectors aligned with ``texts``. Returns ``[]`` for empty
            input, and (best effort) when the API call fails, in which case
            the error is printed.
        """
        # Normalize to a list so a bare string is treated as one document.
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            # Nothing to embed; skip the round trip (an empty input list is
            # not a valid embeddings request).
            return []
        try:
            response = self.client.embeddings.create(
                input=texts,
                model=self.model_name,
            )
            # Each returned item carries the position of its input; sort on it
            # so the output stays aligned with ``texts``.
            ordered = sorted(response.data, key=lambda item: getattr(item, "index", 0))
            return [item.embedding for item in ordered]
        except Exception as e:
            print(f"Error getting embeddings: {e}")
            return []

In [None]:
class _CallableEmbeddingFunction:
    """Adapt a plain ``texts -> vectors`` callable to Chroma's embedding-function shape.

    Chroma validates an embedding function by inspecting the ``__call__`` of
    its class and expects a single parameter named ``input``. A bare function
    or a bound method (such as ``CustomEmbedding.get_embeddings``) fails that
    check, so such callables are wrapped in this object instead.
    """

    def __init__(self, fn):
        self._fn = fn

    def __call__(self, input):  # parameter name is dictated by Chroma's protocol
        return self._fn(list(input))


class ChromaDBManager:
    """Small wrapper around a single Chroma collection."""

    def __init__(self, collection_name: str, embedding_function):
        """Open (or create) ``collection_name`` in a ``chromadb.Client()``.

        Args:
            collection_name: Name of the collection to get or create.
            embedding_function: Either a Chroma embedding-function object, or
                a plain callable mapping a list of texts to a list of vectors
                (it is wrapped automatically).
        """
        import inspect

        if inspect.isroutine(embedding_function):
            embedding_function = _CallableEmbeddingFunction(embedding_function)
        self.client = chromadb.Client()
        self.collection_name = collection_name
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            embedding_function=embedding_function,
        )

    def add_documents(
        self,
        documents: list[str],
        metadatas: list[dict],
        ids: list[str],
        embeddings: list[list[float]] = None,
    ):
        """Add documents to the collection.

        Args:
            documents: Raw texts to store.
            metadatas: One metadata dict per document.
            ids: One unique id per document.
            embeddings: Optional precomputed vectors. When omitted, Chroma
                computes them with the collection's embedding function.
        """
        if not documents:
            # Chroma rejects an empty add; treat it as a no-op.
            return
        payload = {"documents": documents, "metadatas": metadatas, "ids": ids}
        if embeddings is not None:
            payload["embeddings"] = embeddings
        self.collection.add(**payload)

    def query(self, query_text: str, n_results: int = 5) -> list[dict]:
        """Return up to ``n_results`` nearest documents for ``query_text``.

        Each hit is a dict with ``document``, ``metadata`` and ``distance``
        keys, in the order Chroma returns them.
        """
        results = self.collection.query(
            query_texts=[query_text],
            n_results=n_results,
        )
        # Chroma returns one inner list per query text; exactly one was sent.
        documents = results["documents"][0]
        metadatas = results["metadatas"][0]
        distances = results["distances"][0]
        return [
            {"document": doc, "metadata": meta, "distance": dist}
            for doc, meta, dist in zip(documents, metadatas, distances)
        ]

In [None]:
class CustomReranker:
    """Client for a rerank HTTP endpoint (``POST {base_url}/rerank``).

    The response is parsed as ``{"results": [{"index", "relevance_score"}, ...]}``
    (the Cohere/Jina-style rerank schema), so the request sends candidate
    texts in the matching ``documents`` field.
    """

    def __init__(self, api_key, base_url, model_name, timeout: float = 30.0):
        """
        Args:
            api_key: Bearer token for the rerank service.
            base_url: Service root; ``/rerank`` is appended.
            model_name: Reranker model identifier.
            timeout: Seconds to wait for the HTTP request before giving up.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.timeout = timeout

    def rerank(self, query: str, documents: list[str], top_n: int = 3) -> list[dict]:
        """Score ``documents`` against ``query`` and return the best ``top_n``.

        Returns:
            A list of ``{"document": str, "relevance_score": float}`` dicts,
            highest score first. Returns ``[]`` for empty input, or (best
            effort, with a printed message) when the request fails.
        """
        if not documents:
            # Nothing to rank; avoid a pointless network round trip.
            return []
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": documents,
        }
        # Strip a trailing slash so the URL never contains "//rerank".
        url = f"{self.base_url.rstrip('/')}/rerank"
        try:
            # A timeout keeps a stalled service from hanging the whole pipeline.
            response = requests.post(url, headers=headers, json=payload, timeout=self.timeout)
            if response.status_code != 200:
                print(f"Rerank API request failed with status code: {response.status_code}")
                return []
            results = response.json()["results"]
            best = sorted(results, key=lambda x: x["relevance_score"], reverse=True)[:top_n]
            return [
                {"document": documents[item["index"]], "relevance_score": item["relevance_score"]}
                for item in best
            ]
        except Exception as e:
            print(f"Error during reranking: {e}")
            return []

In [None]:
class CustomLLM:
    """Thin client for an OpenAI-compatible chat-completions endpoint."""

    def __init__(self, api_key, base_url, model_name):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def generate(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The user message.
            **kwargs: Extra options forwarded to ``chat.completions.create``
                (for example ``temperature`` or ``max_tokens``).

        Returns:
            The reply text, or ``""`` when the model returned no text content
            or (best effort, with a printed message) when the call failed.
        """
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[{"role": "user", "content": prompt}],
                **kwargs,
            )
            content = response.choices[0].message.content
        except Exception as e:
            print(f"Error generating response: {e}")
            return ""
        # The SDK's message type allows ``content`` to be None; keep the
        # ``-> str`` contract for callers that format or concatenate it.
        return content or ""

## 3. RAG框架整合

In [None]:
import os
import uuid

class SimpleRAG:
    """Minimal pipeline: retrieve from Chroma, optionally rerank, then generate.

    All model endpoints and credentials are read from environment variables
    in ``__init__``; documents live in a Chroma collection named
    ``rag_collection`` managed by ``ChromaDBManager``.
    """

    def __init__(self):
        # API configuration, overridable through environment variables.
        # NOTE(review): the default base URLs are relative paths; presumably a
        # full host must be supplied via the *_BASE_URL variables — confirm.
        self.embedding_api_key = os.getenv('EMBEDDING_API_KEY', '')
        self.embedding_base_url = os.getenv('EMBEDDING_BASE_URL', '/api/inference/v1')
        self.embedding_model_name = os.getenv('EMBEDDING_MODEL_NAME', 'bge-large-zh-v1.5')

        self.rerank_api_key = os.getenv('RERANKER_API_KEY', '')
        self.rerank_base_url = os.getenv('RERANKER_BASE_URL', '/api/inference/v1')
        self.rerank_model_name = os.getenv('RERANKER_MODEL_NAME', 'bge-reranker-v2-m3')

        self.llm_api_key = os.getenv('LLM_API_KEY', '')
        self.llm_base_url = os.getenv('LLM_BASE_URL', '/api/inference/v1')
        self.llm_model_name = os.getenv('LLM_MODEL_NAME', 'GLM-4.6-FP8')

        # Initialize the component clients.
        self.embedding_client = CustomEmbedding(
            self.embedding_api_key,
            self.embedding_base_url,
            self.embedding_model_name
        )

        # Chroma embeds both stored documents and query texts through this
        # function, so no separate embedding calls are needed in ingest/query.
        self.db_manager = ChromaDBManager(
            "rag_collection",
            self.embedding_client.get_embeddings
        )

        self.reranker = CustomReranker(
            self.rerank_api_key,
            self.rerank_base_url,
            self.rerank_model_name
        )

        self.llm_client = CustomLLM(
            self.llm_api_key,
            self.llm_base_url,
            self.llm_model_name
        )

    def ingest(self, documents: list[str], metadatas: list[dict] = None, ids: list[str] = None):
        """Store ``documents`` in the vector collection.

        Args:
            documents: Texts to index. An empty list is a no-op.
            metadatas: One dict per document; defaults to ``{"source": "default"}``.
            ids: One id per document; random UUID4 strings are generated if omitted.
        """
        if not documents:
            return
        # Fill in metadatas and ids when the caller did not provide them.
        if metadatas is None:
            metadatas = [{"source": "default"} for _ in documents]
        if ids is None:
            ids = [str(uuid.uuid4()) for _ in documents]
        # Chroma computes the vectors itself via the collection's embedding
        # function; computing them here as well would only waste an API call.
        self.db_manager.add_documents(documents, metadatas, ids)

    def query(self, question: str, use_rerank: bool = True,
              n_results: int = 5, top_n: int = 3) -> dict:
        """Answer ``question`` from the indexed documents.

        Args:
            question: The user question.
            use_rerank: Whether to reorder retrieved candidates with the reranker.
            n_results: How many candidates to retrieve from the vector store.
            top_n: How many candidates to place in the LLM context.

        Returns:
            ``{"question": ..., "context": ..., "answer": ...}`` where
            ``context`` is the newline-joined text given to the LLM.
        """
        # Step 1: vector retrieval (Chroma embeds the question itself).
        retrieved_docs = self.db_manager.query(question, n_results=n_results)
        candidates = [doc["document"] for doc in retrieved_docs]

        # Step 2: pick the context. Default to retrieval order; if reranking is
        # requested and succeeds, use its order instead. The reranker returns
        # [] on failure, which must not wipe out the retrieved context.
        selected = candidates[:top_n]
        if use_rerank and candidates:
            reranked_docs = self.reranker.rerank(question, candidates, top_n=top_n)
            if reranked_docs:
                selected = [doc["document"] for doc in reranked_docs]
        context = "\n".join(selected)

        # Step 3: build the prompt.
        prompt = f"""基于以下上下文回答问题：

上下文：
{context}

问题：{question}

请根据上下文回答问题。如果上下文中没有相关信息，请说明无法基于提供的上下文回答问题。"""

        # Step 4: generate the answer.
        answer = self.llm_client.generate(prompt)

        # Step 5: return a structured result.
        return {
            "question": question,
            "context": context,
            "answer": answer
        }

## 4. 端到端演示

In [None]:
# Sample corpus used by the end-to-end demo.
documents = [
    "人工智能（Artificial Intelligence，AI）是指由人类制造出来的机器所表现出来的智能。通常人工智能是指通过普通计算机程序来呈现人类智能的技术。",
    "机器学习是人工智能的一个分支，它使计算机能够从数据中学习并做出决策或预测，而无需明确编程来执行特定任务。",
    "深度学习是机器学习的一个子集，它模仿人脑的工作方式，使用神经网络来处理和学习复杂的数据模式。深度学习在图像识别、语音识别和自然语言处理等领域取得了显著成果。",
    "自然语言处理（NLP）是计算机科学和人工智能领域的一个重要方向，它致力于让计算机理解和生成人类语言。",
    "计算机视觉是一门研究如何让计算机'看'的科学，它赋予了计算机理解和解释图像和视频内容的能力。",
]

# Build the pipeline and load the sample corpus into it.
rag = SimpleRAG()
rag.ingest(documents)

# Ask one question, then show the question, the retrieved context and the answer.
question = "请解释一下什么是深度学习，它与机器学习有什么关系？"
result = rag.query(question)

print(
    f"问题：{result['question']}",
    f"\n检索到的上下文：\n{result['context']}",
    f"\n生成的答案：\n{result['answer']}",
    sep="\n",
)