In [1]:
#加载数据集
import json

data = []
with open('../data/crag_data_200.jsonl', 'r') as file:
    for line in file:
        data.append(json.loads(line))

print(f"Loaded {len(data)} records.")


Loaded 200 records.


In [2]:
import os

# 获取并打印当前工作目录
current_working_directory = os.getcwd()
print(f"当前工作目录是: {current_working_directory}")


当前工作目录是: /home/u2021213565/jupyterlab/RAG2406/RAG_HBX


In [3]:
from transformers import AutoTokenizer, AutoModel
import torch
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.utils import embedding_functions

model_name = "BAAI/bge-small-en-v1.5"
local_model_path = "./model/checkpoint-15000"

class MyEmbeddingFunction(EmbeddingFunction):
    def __init__(self, model_path: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)

    def __call__(self, input: Documents) -> Embeddings:
        inputs = self.tokenizer(input, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Assuming the embeddings are in the last hidden state and taking the mean pooling
        embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
        return embeddings.tolist()


# 使用你的bge-small-en-v1.5模型创建自定义嵌入函数
my_embedding_function = MyEmbeddingFunction(model_path=local_model_path)



In [5]:
#创建数据库
import chromadb

# Initialize Chroma client
# client = chromadb.Client()
client = chromadb.PersistentClient(path="./chroma_pipeline")
# Create a new collection
collection = client.create_collection(name="collection_embedding", embedding_function=my_embedding_function)
# collection = client.create_collection('crag_documents')


In [6]:
from langchain_text_splitters import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=64, chunk_overlap=2, add_start_index=True
)

In [7]:
# 准备并插入数据到Chroma
from tqdm import tqdm

documents = []
metadatas = []
ids = []

for i, record in enumerate(tqdm(data, desc="Processing records")):
    query = record["query"]
    answer = record["answer"]
    search_results = record["search_results"]
    
    # 创建一个包含query和answer的文档字符串，并添加search_results的内容
    # document = query + " " + answer + " " + " ".join([result["page_snippet"] for result in search_results])
    text = ''.join([result["page_snippet"] for result in search_results])
    doc=text_splitter.create_documents([text])
    all_splits = text_splitter.split_documents(doc)
    
    # 遍历所有分块并分别添加到相应的列表中
    for j, split in enumerate(all_splits):
        documents.append(split.page_content)
        metadatas.append({
            "query": query,
            "answer": answer,
            "split_index": j
        })
        ids.append(f"doc_{i}_split_{j}")

# 向集合中添加文档
collection.add(
    documents=documents,  # 自动处理分词、嵌入和索引
    metadatas=metadatas,  # 可以根据这些元数据进行过滤
    ids=ids  # 每个文档的唯一标识
)

print("Data has been inserted into Chroma.")


Processing records: 100%|██████████| 200/200 [00:01<00:00, 118.41it/s]


Data has been inserted into Chroma.


In [8]:
# 测试用例

# 查询
query = "which is a better investment, gold or silver, when considering long-term return?"

# 执行查询，先筛选出包含query的所有文本块，再获取最相似的三个结果
results = collection.query(
    query_texts=[query],
    n_results=5,
    where={"query": query},  # 筛选包含该query的所有文本块
)

print("Top 3 relevant documents:")
context = ""
for result in results['documents']:
    for doc in result:  # 遍历列表中的每一个文档
        print(doc)
        context += doc + "\n"

print("Context:")
print(context)


Top 3 relevant documents:
Is gold or silver the best choice for investing? We compare the
considering silver as an investment option then you would be
considering silver as an investment option then you would be
investment option.Though, as an investment, silver is not as
as an investment, silver is not as popular as gold, it is in
Context:
Is gold or silver the best choice for investing? We compare the
considering silver as an investment option then you would be
considering silver as an investment option then you would be
investment option.Though, as an investment, silver is not as
as an investment, silver is not as popular as gold, it is in



In [11]:
import requests
# 设置百度API密钥
API_KEY = "r3eHPr2Da9rJa1yJpd5qCGy3"
SECRET_KEY = "Wyi1aOmWcjt70roUhQS4v2GbOvmfGXHn"

def get_access_token():
    """
    使用 API Key，Secret Key 获取access_token
    """
    url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={API_KEY}&client_secret={SECRET_KEY}"
    response = requests.post(url)
    return response.json().get("access_token")

def call_baidu_chat_api(query, context):
    access_token = get_access_token()
    url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/yi_34b_chat?access_token={access_token}"
    payload = json.dumps({
        "messages": [
            {
                "role": "user",
                "content": f"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Answer in 2 to 3 sentences. Context: {context} Question: {query}",
            }
        ]
    })
    headers = {
        'Content-Type': 'application/json'
    }
    response = requests.post(url, headers=headers, data=payload)
    return response.json()

In [12]:
# 测试用例
query = "Who was the first president of the United States?"
context = """
The first president of the United States was George Washington. He served two terms as president from 1789 to 1797. Washington is often referred to as the "Father of His Country" for his pivotal role in the founding of the United States. Before becoming president, he served as the commander-in-chief of the Continental Army during the American Revolutionary War and presided over the convention that drafted the U.S. Constitution.
"""
# 调用百度API获取答案
response = call_baidu_chat_api(query, context)
print("Response from Baidu API:")
print(response)
if "result" in response and response["result"]:
    print(response["result"])

Response from Baidu API:
{'id': 'as-2vuzgypb0s', 'object': 'chat.completion', 'created': 1719303448, 'result': 'The first president of the United States was George Washington. He served two terms as president from 1789 to 1797. Washington is often referred to as the "Father of His Country" for his pivotal role in the founding of the United States. Before becoming president, he served as the commander-in-chief of the Continental Army during the American Revolutionary War and presided over the convention that drafted the U.S. Constitution.', 'is_truncated': False, 'need_clear_history': False, 'usage': {'prompt_tokens': 162, 'completion_tokens': 94, 'total_tokens': 256}}
The first president of the United States was George Washington. He served two terms as president from 1789 to 1797. Washington is often referred to as the "Father of His Country" for his pivotal role in the founding of the United States. Before becoming president, he served as the commander-in-chief of the Continental Arm

In [13]:
def query_database(query):
    results = collection.query(
    query_texts=[query],
    n_results=5,
    where={"query": query},  # 筛选包含该query的所有文本块
                        )
    context = ""
    for result in results['documents']:
        for doc in result:  # 遍历列表中的每一个文档
            context += doc + "\n"

    return context

In [14]:
#测试用例，先问它个5回
for i, record in enumerate(tqdm(data, desc="Processing records")):
    query = record["query"]
    answer = record["answer"]
    search_results = record["search_results"]
    context = query_database(query)
    response = call_baidu_chat_api(query, context)
    print(f"Query: {query}")
    print(f"Answer: {answer}")
    pred = response['result']
    print(f"Response from Baidu API: {pred}")
    if i > 5:
        break

Processing records:   0%|          | 1/200 [00:04<15:01,  4.53s/it]

Query: which is a better investment, gold or silver, when considering long-term return?
Answer: gold
Response from Baidu API: When considering long-term return, gold is generally considered the better investment option compared to silver. Gold has a history of maintaining its value over time and is often used as a hedge against inflation. While silver can also be a good investment, it is more volatile and subject to more price fluctuations than gold, making it a riskier choice for long-term investment purposes.


Processing records:   1%|          | 2/200 [00:07<11:33,  3.50s/it]

Query: which country is the largest gold producer?
Answer: china
Response from Baidu API: According to the provided context, the largest gold producing country in the world is China. It produced 403 metric tons of gold in 2012, followed by Australia and Russia.


Processing records:   2%|▏         | 3/200 [00:11<12:51,  3.92s/it]

Query: what is the name of priyanka chopra's fashion line?
Answer: invalid question
Response from Baidu API: I'm sorry, but I don't have information about Priyanka Chopra having her own fashion line. She is known for her fashion sense and has been a prominent figure in the entertainment industry, but I'm not aware of her having a specific brand or line of clothing. If you have any other questions, I'll be happy to help!


Processing records:   2%|▏         | 4/200 [00:14<11:18,  3.46s/it]

Query: who has won more atp singles titles, roger federer or rafael nadal?
Answer: rafael nadal has won more atp singles titles than roger federer, with 86 titles compared to federer's 82.
Response from Baidu API: Roger Federer has won more ATP singles titles than Rafael Nadal. Federer has won 103 ATP titles in his career, while Nadal has won 85 ATP titles.


Processing records:   2%|▎         | 5/200 [00:18<11:48,  3.63s/it]

Query: who was responsible for initiating the construction of the badshahi mosque?
Answer: aurangzeb
Response from Baidu API: The construction of the Badshahi Mosque was initiated by Emperor Aurangzeb, the sixth Mughal Emperor, who reigned from 1658 to 1707. He was responsible for commissioning the mosque, which was built between 1671 and 1673.


Processing records:   3%|▎         | 6/200 [00:23<13:09,  4.07s/it]

Query: what is the price of bitcoin when it launch in 2015?
Answer: invalid question
Response from Baidu API: The price of Bitcoin when it launched in 2015 was under $5. However, it's important to note that the question seems to be confused about the launch date of Bitcoin. Bitcoin was first introduced in 2009, not 2015. The price of Bitcoin in 2009 was very different from what it was in 2015.


Processing records:   3%|▎         | 6/200 [00:25<13:41,  4.23s/it]

Query: how much rings does steve kerr hold?
Answer: kerr is a nine-time nba champion
Response from Baidu API: Steve Kerr holds 9 rings in total, 5 of them as a player and 4 as a coach.





In [None]:
result = []
for i, record in enumerate(tqdm(data, desc="Processing records")):
    query = record["query"]
    answer = record["answer"]
    search_results = record["search_results"]
    context = query_database(query)
    response = call_baidu_chat_api(query, context)
    pred = response['result']
    result.append(json.dumps({'query': query, 'answer': answer, 'pred': pred}, ensure_ascii=False) + '\n')
with open("../data/baidu_chat_results.jsonl", 'w', encoding='utf-8') as f:
    f.write(''.join(result))

In [15]:
def continue_processing(data, target_file='./baidu_chat_results.jsonl'):
    # 检查目标文件的目录是否存在，不存在则创建
    target_dir = os.path.dirname(target_file)
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    # 读取已有文件的行数，确定从哪一行开始继续处理
    if os.path.exists(target_file):
        with open(target_file, 'r', encoding='utf-8') as f:
            processed_lines = len(f.readlines())
    else:
        processed_lines = 0

    result = []
    for i, record in enumerate(tqdm(data[processed_lines:], desc="Processing records", initial=processed_lines, total=len(data))):
        query = record["query"]
        answer = record["answer"]
        search_results = record["search_results"]
        context = query_database(query)
        response = call_baidu_chat_api(query, context)
        if 'result' in response:
            pred = response['result']
        else:
            print(f"Missing 'result' key in response for query: {query}")
            pred = None
        result.append(json.dumps({'query': query, 'answer': answer, 'pred': pred}, ensure_ascii=False) + '\n')

        # 每处理一条记录就立即写入文件，防止中途出错数据丢失
        with open(target_file, 'a', encoding='utf-8') as f:
            f.write(result[-1])

In [None]:
import os
continue_processing(data)

Processing records:   4%|▍         | 8/200 [00:31<12:27,  3.90s/it]