In [1]:
#加载数据集
import json

data = []
with open('../data/crag_data_200.jsonl', 'r') as file:
    for line in file:
        data.append(json.loads(line))

print(f"Loaded {len(data)} records.")


Loaded 200 records.


In [2]:
import os

# 获取并打印当前工作目录
current_working_directory = os.getcwd()
print(f"当前工作目录是: {current_working_directory}")


当前工作目录是: /home/bangx/RAG_2406/RAG_HBX


In [4]:
from transformers import AutoTokenizer, AutoModel
import torch
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions

model_name = "BAAI/bge-small-en-v1.5"
local_model_path = "./model-qxk/checkpoint-15000"

class MyEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)

    def __call__(self, input: Documents) -> Embeddings:
        inputs = self.tokenizer(input, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Assuming the embeddings are in the last hidden state and taking the mean pooling
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return embeddings.tolist()


# 使用你的bge-small-en-v1.5模型创建自定义嵌入函数
my_embedding_function = MyEmbeddingFunction(model_path=local_model_path)



In [5]:
#创建数据库
import chromadb

# Initialize Chroma client
# client = chromadb.Client()
client = chromadb.PersistentClient(path="./chroma_pipeline_qxk")
# Create a new collection
collection = client.get_collection(name="collection_embedding", embedding_function=my_embedding_function)
# collection = client.create_collection('crag_documents')


In [6]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=64, chunk_overlap=2, add_start_index=True
)

In [7]:
# 准备并插入数据到Chroma
from tqdm import tqdm

documents = []
metadatas = []
ids = []

for i, record in enumerate(tqdm(data, desc="Processing records")):
    query = record["query"]
    answer = record["answer"]
    search_results = record["search_results"]
    
    # 创建一个包含query和answer的文档字符串，并添加search_results的内容
    # document = query + " " + answer + " " + " ".join([result["page_snippet"] for result in search_results])
    text = ''.join([result["page_snippet"] for result in search_results])
    doc=text_splitter.create_documents([text])
    all_splits = text_splitter.split_documents(doc)
    
    # 遍历所有分块并分别添加到相应的列表中
    for j, split in enumerate(all_splits):
        documents.append(split.page_content)
        metadatas.append({
            "query": query,
            "answer": answer,
            "split_index": j
        })
        ids.append(f"doc_{i}_split_{j}")

# 向集合中添加文档
collection.add(
    documents=documents,  # 自动处理分词、嵌入和索引
    metadatas=metadatas,  # 可以根据这些元数据进行过滤
    ids=ids  # 每个文档的唯一标识
)

print("Data has been inserted into Chroma.")


Processing records: 100%|██████████| 200/200 [00:01<00:00, 118.41it/s]


Data has been inserted into Chroma.


In [16]:
# 测试用例

# 查询
query = "which is a better investment, gold or silver, when considering long-term return?"

# 执行查询，先筛选出包含query的所有文本块，再获取最相似的三个结果
results = collection.query(
    query_texts=[query],
    n_results=15,
    where={"query": query},  # 筛选包含该query的所有文本块
)

print("Top 3 relevant documents:")
context = ""
for result in results['documents']:
    for doc in result:  # 遍历列表中的每一个文档
        print(doc)
        context += doc + "\n"

print("Context:")
print(context)


Top 3 relevant documents:
Is gold or silver the best choice for investing? We compare the
as an investment, silver is not as popular as gold, it is in
considering silver as an investment option then you would be
considering silver as an investment option then you would be
investment option.Though, as an investment, silver is not as
choosing investments can be difficult. Some want Gold or Silver
investing in silver can be a wise choice. You can also avail
silver. To put it into perspective, if you invested £50,000 in
for this precious metal. Hence, investing in silver can be a
between silver and gold, if you bought both in equal monetary
risk. ... Silver can be considered a good portfolio diversifier
returns. However, if you are considering silver as an
silver. Silver prices have higher price volatility and can be
you invested £50,000 in silver, you would need approximately
investments can be difficult. Some want Gold or Silver coins,
Context:
Is gold or silver the best choice for inves

In [9]:
import os
from openai import OpenAI

# 设置API密钥
os.environ['OPENAI_API_KEY'] = 

# 设置TOKENIZERS_PARALLELISM环境变量
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# 实例化OpenAI客户端
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

def query_openai(prompt, model="gpt-3.5-turbo", max_tokens=100):
    """
    调用OpenAI API进行查询并获取响应。

    :param prompt: 输入的提示文本
    :param model: 使用的模型，默认为gpt-3.5-turbo
    :param max_tokens: 响应的最大token数，默认为100
    :return: OpenAI模型的响应文本
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=max_tokens,
        temperature=0.7,
    )
    return response.choices[0].message.content.strip()

# 示例查询
prompt = "What is task decomposition for LLM agents?"
try:
    response = query_openai(prompt)
    print("Response from OpenAI:", response)
except Exception as e:
    print(f"An error occurred: {e}")

Response from OpenAI: Task decomposition for Large Language Models (LLMs) involves breaking down a complex task into smaller, more manageable sub-tasks that the model can process sequentially or in parallel. This approach allows the LLM to address tasks that require multiple steps or components by dividing them into simpler tasks that the model can handle more effectively.

By decomposing tasks, LLM agents can achieve better performance, reduce computational complexity, and improve efficiency in handling a wide range of tasks. Task decomposition can also help LLMs to


In [13]:
def call_openai_api(query, context):
    prompt=f"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Answer in 2 to 3 sentences. \nQuery: {query}\nContext: {context}\n"

    try:
        response = query_openai(prompt)
    except Exception as e:
        print(f"An error occurred: {e}")
    
    return response


In [29]:
import requests
# 设置百度API密钥


def get_access_token():
    """
    使用 API Key，Secret Key 获取access_token
    """
    url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={API_KEY}&client_secret={SECRET_KEY}"
    response = requests.post(url)
    return response.json().get("access_token")

def call_baidu_chat_api(query, context):
    access_token = get_access_token()
    url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat?access_token={access_token}"
    payload = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": f"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Answer in 2 to 3 sentences. Context: {context} Question: {query}",
            }
        ]
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.post(url, headers=headers, data=payload)
    return response.json()

In [14]:
# 测试用例
query = "Who was the first president of the United States?"
context = """
The first president of the United States was George Washington. He served two terms as president from 1789 to 1797. Washington is often referred to as the "Father of His Country" for his pivotal role in the founding of the United States. Before becoming president, he served as the commander-in-chief of the Continental Army during the American Revolutionary War and presided over the convention that drafted the U.S. Constitution.
"""
# 调用百度API获取答案
response = call_openai_api(query, context)
print("Response from Baidu API:")
print(response)
if "result" in response and response["result"]:
    print(response["result"])

Response from Baidu API:
The first president of the United States was George Washington. He served two terms from 1789 to 1797 and is often referred to as the "Father of His Country" for his significant contributions to the founding of the nation.


In [19]:
def query_database(query):
    results = collection.query(
    query_texts=[query],
    n_results=15,
    where={"query": query},  # 筛选包含该query的所有文本块
                        )
    context = ""
    for result in results['documents']:
        for doc in result:  # 遍历列表中的每一个文档
            context += doc + "\n"

    return context

In [11]:
def query_database(query):
    results = collection.query(
        query_texts=[query],
        n_results=15,
        where={"query": query},  # 筛选包含该query的所有文本块
    )
    
    context = ""
    doc_count = {}  # 记录每个文档的文本块出现次数
    doc_texts = {}  # 记录每个文档的文本块内容

    for result in results['documents']:
        for doc in result:  # 遍历列表中的每一个文档
            doc_id = doc.split("_split_")[0]  # 提取doc_{i}部分
            split_index = int(doc.split("_split_")[1])  # 提取块索引
            if doc_id not in doc_count:
                doc_count[doc_id] = 0
                doc_texts[doc_id] = {}
            doc_count[doc_id] += 1
            doc_texts[doc_id][split_index] = doc
    
    # 阈值定义
    threshold = 3
    
    for doc_id, count in doc_count.items():
        if count >= threshold:
            # 找到最小和最大的split_index
            min_index = min(doc_texts[doc_id].keys())
            max_index = max(doc_texts[doc_id].keys())
            # 合并并补全同一文档的所有文本块
            merged_text = ""
            for i in range(min_index, max_index + 1):
                if i in doc_texts[doc_id]:
                    merged_text += doc_texts[doc_id][i] + "\n"
                else:
                    # 补全缺失的块
                    merged_text += f"{doc_id}_split_{i}\n"
            context += merged_text + "\n"
        else:
            # 直接添加到context中
            for index in sorted(doc_texts[doc_id].keys()):
                context += doc_texts[doc_id][index] + "\n"
    
    return context


In [17]:
def query_database(query):
    results = collection.query(
        query_texts=[query],
        n_results=15,
        where={"query": query},  # 筛选包含该query的所有文本块
    )
    
    context = ""
    doc_count = {}  # 记录每个文档的文本块出现次数
    doc_texts = {}  # 记录每个文档的文本块内容

    for result in results['documents']:
        for doc in result:  # 遍历列表中的每一个文档
            parts = doc.split("_split_")
            if len(parts) == 2:
                doc_id = parts[0]  # 提取doc_{i}部分
                try:
                    split_index = int(parts[1])  # 提取块索引
                except ValueError:
                    continue  # 忽略无法转换为整数的索引

                if doc_id not in doc_count:
                    doc_count[doc_id] = 0
                    doc_texts[doc_id] = {}
                doc_count[doc_id] += 1
                doc_texts[doc_id][split_index] = doc
            else:
                continue  # 忽略不符合预期格式的文档
    
    # 阈值定义
    threshold = 3
    
    for doc_id, count in doc_count.items():
        if count >= threshold:
            # 找到最小和最大的split_index
            min_index = min(doc_texts[doc_id].keys())
            max_index = max(doc_texts[doc_id].keys())
            # 合并并补全同一文档的所有文本块
            merged_text = ""
            for i in range(min_index, max_index + 1):
                if i in doc_texts[doc_id]:
                    merged_text += doc_texts[doc_id][i] + "\n"
                else:
                    # 补全缺失的块
                    merged_text += f"{doc_id}_split_{i}\n"
            context += merged_text + "\n"
        else:
            # 直接添加到context中
            for index in sorted(doc_texts[doc_id].keys()):
                context += doc_texts[doc_id][index] + "\n"
    
    return context


In [20]:
#测试用例，先问它个5回
from tqdm import tqdm

for i, record in enumerate(tqdm(data, desc="Processing records")):
    query = record["query"]
    answer = record["answer"]
    search_results = record["search_results"]
    context = query_database(query)
    response = call_openai_api(query, context)
    print(f"Query: {query}")
    print(f"Answer: {answer}")
    # pred = response['result']
    print(f"Response from Baidu API: {response}")
    if i > 1:
        break

Processing records:   0%|          | 1/200 [00:02<09:33,  2.88s/it]

Query: which is a better investment, gold or silver, when considering long-term return?
Answer: gold
Response from Baidu API: When considering long-term return, gold is generally considered a better investment compared to silver. Gold is more popular and often seen as a safer asset with more stable value over time. Silver, on the other hand, tends to have higher price volatility and may not offer the same level of long-term return as gold.


Processing records:   1%|          | 2/200 [00:03<05:56,  1.80s/it]

Query: which country is the largest gold producer?
Answer: china
Response from Baidu API: China is the largest gold producer in the world, producing 403 tonnes of gold alone.


Processing records:   1%|          | 2/200 [00:05<08:58,  2.72s/it]

Query: what is the name of priyanka chopra's fashion line?
Answer: invalid question
Response from Baidu API: I'm sorry, but based on the provided context, there is no specific mention of the name of Priyanka Chopra's fashion line.





In [None]:
result = []
for i, record in enumerate(tqdm(data, desc="Processing records")):
    query = record["query"]
    answer = record["answer"]
    search_results = record["search_results"]
    context = query_database(query)
    response = call_baidu_chat_api(query, context)
    pred = response['result']
    result.append(json.dumps({'query': query, 'answer': answer, 'pred': pred}, ensure_ascii=False) + '\n')
with open("../data/baidu_chat_results.jsonl", 'w', encoding='utf-8') as f:
    f.write(''.join(result))

In [27]:
def continue_processing(data, target_file='./baidu_chat_results.jsonl'):
    # 检查目标文件的目录是否存在，不存在则创建
    target_dir = os.path.dirname(target_file)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # 读取已有文件的行数，确定从哪一行开始继续处理
    if os.path.exists(target_file):
        with open(target_file, 'r', encoding='utf-8') as f:
            processed_lines = len(f.readlines())
    else:
        processed_lines = 0

    result = []
    for i, record in enumerate(tqdm(data[processed_lines:], desc="Processing records", initial=processed_lines, total=len(data))):
        query = record["query"]
        answer = record["answer"]
        search_results = record["search_results"]
        context = query_database(query)
        response = call_openai_api(query, context)
        if response:
            pred = response
        else:
            print(f"Missing 'result' key in response for query: {query}")
            pred = None
        result.append(json.dumps({'query': query, 'answer': answer, 'pred': pred}, ensure_ascii=False) + '\n')

        # 每处理一条记录就立即写入文件，防止中途出错数据丢失
        with open(target_file, 'a', encoding='utf-8') as f:
            f.write(result[-1])

In [28]:
import os
continue_processing(data)

Processing records: 100%|██████████| 200/200 [04:50<00:00,  1.45s/it]


In [25]:
def continue_processing(data, target_file='./baidu_chat_results.jsonl'):
    # 检查目标文件的目录是否存在，不存在则创建
    target_dir = os.path.dirname(target_file)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    existing_data = []
    missing_results_indices = []

    # 读取已有文件的内容，找出pred为null的记录
    if os.path.exists(target_file):
        with open(target_file, 'r', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                record = json.loads(line)
                existing_data.append(record)
                if record['pred'] is None:
                    missing_results_indices.append(idx)

    # 处理缺失 result 的记录
    if missing_results_indices:
        print(f"Reprocessing {len(missing_results_indices)} missing results...")
        for idx in tqdm(missing_results_indices, desc="Reprocessing missing records"):
            record = existing_data[idx]
            query = record["query"]
            context = query_database(query)
            response = call_openai_api(query, context)
            if response:
                record['pred'] = response
            else:
                print(f"Still missing 'result' key in response for query: {query}")
                record['pred'] = None

    # 将所有处理后的记录按原始顺序写回文件
    with open(target_file, 'w', encoding='utf-8') as f:
        for record in existing_data:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')