In [3]:
import os
import torch
from fastapi import FastAPI, HTTPException
from sentence_transformers import SentenceTransformer, CrossEncoder
from pymilvus import MilvusClient
from dotenv import load_dotenv
from contextlib import asynccontextmanager
from typing import Dict, Union
from rag_schemas import SearchRequest, ScrapeRequest, SearchResult, ScrapeResult
from rag_config import RAGConfig

load_dotenv()


def format_queries(query, instruction=None):
    prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. \
        Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
    if instruction is None:
        instruction = (
            "Given a web search query, retrieve relevant passages that answer the query"
        )
    return f"{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"


def format_document(document):
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    return f"<Document>: {document}{suffix}"


# 任务描述
TASK = "Given a web search query, retrieve relevant passages that answer the query"

config = RAGConfig()

print("正在加载模型和元数据...")


print(f"正在连接Milvus: {config.milvus_endpoint}")
client = MilvusClient(uri=config.milvus_endpoint, token=config.milvus_token)

# 检查Collection是否存在
if not client.has_collection(collection_name=config.milvus_collection_name):
    raise SystemExit(
        f"错误: Milvus中不存在名为 '{config.milvus_collection_name}' 的集合。请先运行`build_datasets.ipynb`脚本来创建集合。"
    )



正在加载模型和元数据...
正在连接Milvus: https://in03-fe6558fabb7b567.serverless.ali-cn-hangzhou.cloud.zilliz.com.cn


In [10]:
embedding_model = SentenceTransformer(
        f"Qwen/{config.embedding_model_name}",
        cache_folder=f".{config.cache_dir}/{config.embedding_model_name}",
        model_kwargs={
            # "attn_implementation": "flash_attention_2",
            # "torch_dtype": torch.bfloat16,
        },
        tokenizer_kwargs={"padding_side": "left"},
        device="cuda:4",
    )

    # 加载用于重排序的Cross-Encoder模型
reranker_model = CrossEncoder(
        f"tomaarsen/{config.reranker_model_name}",
        cache_folder=f".{config.cache_dir}/{config.reranker_model_name}",
        model_kwargs={
            # "attn_implementation": "flash_attention_2",
            # "torch_dtype": torch.bfloat16,
        },
        tokenizer_kwargs={
            "padding_side": "left",
        },
        device="cuda:4",
    )

In [11]:
query_list = [
     "What is the capital of France?",
     "Explain the theory of relativity.",
 ]
query_vector = embedding_model.encode_query(
        query_list, normalize_embeddings=True
    )

In [26]:
results = client.search(
    collection_name=config.milvus_collection_name,
    data=query_vector,
    anns_field="vector",  # 您在Milvus中定义的向量字段名
    search_params={"metric_type": "IP", "params": {}},
    limit=config.top_k_retrieval,  # 返回的文档数量
    output_fields=["text", "title", "id"],  # 输出字段包括文本和ID和标题
)

print(f"查询结果数量: {len(results)}")
print(f"每个查询的文档数量: {[len(docs) for docs in results]}")

查询结果数量: 2
每个查询的文档数量: [50, 50]


In [None]:
# 针对每一个查询去重
unique_docs = []
for result in results:
    temp = []
    seen_titles = set()
    for hit in result:
        title = hit["entity"]["title"]
        if title not in seen_titles:
            seen_titles.add(title)
            temp.append(
                {
                    "id": hit["entity"]["id"],  # 保存原始ID
                    "title": title,
                    "text": hit["entity"]["text"],
                }
            )
    unique_docs.append(temp)

print(f"查询结果数量: {len(unique_docs)}")
print(f"每个查询的文档数量: {[len(docs) for docs in unique_docs]}")

查询结果数量: 2
每个查询的文档数量: [6, 8]
查询结果示例: [{'id': '2hop__152921_19219_0', 'title': 'Paris', 'text': "Paris (French pronunciation: \u200b (paʁi) (listen)) is the capital and most populous city in France, with an administrative - limits area of 105 square kilometres (41 square miles) and an official population of 2,206,488 (2015). The city is a commune and department, and the heart of the 12,012 - square - kilometre (4,638 - square - mile) Île - de-France region (colloquially known as the 'Paris Region'), whose 2016 population of 12,142,802 represented roughly 18 percent of the population of France. Since the 17th century, Paris has been one of Europe's major centres of finance, commerce, fashion, science, and the arts. The Paris Region had a GDP of €649.6 billion (US $763.4 billion) in 2014, accounting for 30.4 percent of the GDP of France. According to official estimates, in 2013 - 14 the Paris Region had the third - highest GDP in the world and the largest regional GDP in the EU."}, {'id': 

In [28]:
query_for_reranker = [format_queries(query, TASK) for query in query_list]

documents_for_reranker = []
for docs in unique_docs:
    documents_for_reranker.append(format_document(doc["text"]) for doc in docs)


len(query_for_reranker), len(documents_for_reranker)

(2, 2)

In [29]:
# **重排序 (Reranking)
# 使用Cross-Encoder模型计算相关性分数
ranks = [
    reranker_model.rank(
        query=query,
        documents=doc,
    )
    for query, doc in zip(query_for_reranker, documents_for_reranker)
]

ranks

[[{'corpus_id': 0, 'score': 0.99848026},
  {'corpus_id': 3, 'score': 0.9896303},
  {'corpus_id': 1, 'score': 0.91427857},
  {'corpus_id': 2, 'score': 0.7298417},
  {'corpus_id': 4, 'score': 0.0075455094},
  {'corpus_id': 5, 'score': 0.0013671208}],
 [{'corpus_id': 0, 'score': 0.9959792},
  {'corpus_id': 1, 'score': 0.890301},
  {'corpus_id': 6, 'score': 0.43265003},
  {'corpus_id': 5, 'score': 0.2497432},
  {'corpus_id': 2, 'score': 0.050744627},
  {'corpus_id': 3, 'score': 0.039595332},
  {'corpus_id': 4, 'score': 0.013334567},
  {'corpus_id': 7, 'score': 0.009522346}]]

In [33]:
# 格式化并返回最终结果
final_results = []
for index, rank in enumerate(ranks):
    temp = []
    # 只处理分数最高的TOP_K_RERANKED个结果
    for rank_info in rank[: config.top_k_rerank]:
        # `corpus_id`是reranker输入列表的索引，也对应我们`unique_docs`列表的索引
        corpus_id = rank_info["corpus_id"]

        # 从我们保存的原始文档列表中获取完整信息
        original_doc = unique_docs[index][corpus_id]

        # 构建最终返回的SearchResult对象
        temp.append(
            SearchResult(
                id=original_doc["id"],
                title=original_doc["title"],
                # 从原始文本生成预览
                preview=original_doc["text"][:150] + "...",
                rerank_score=rank_info["score"],
            )
        )
    final_results.append(temp)

final_results

[[SearchResult(id='2hop__152921_19219_0', title='Paris', preview='Paris (French pronunciation: \u200b (paʁi) (listen)) is the capital and most populous city in France, with an administrative - limits area of 105 square ki...', rerank_score=0.9984802603721619),
  SearchResult(id='3hop1__238983_403313_22725_8', title='Strasbourg', preview='Strasbourg (/ˈstræzbɜːrɡ/, French pronunciation: \u200b[stʁaz.buʁ, stʁas.buʁ]; Alsatian: Strossburi; German: Straßburg, [ˈʃtʁaːsbʊɐ̯k]) is the capital and ...', rerank_score=0.9896302819252014),
  SearchResult(id='3hop1__161098_80462_90300_1', title='Marseille', preview='Marseille (English / mɑːrˈseɪ /; French: (maʁsɛj) (listen), locally: (mɑχˈsɛjə); Provençal Marselha (maʀˈsejɔ, maʀˈsijɔ)), also known as Marseilles, i...', rerank_score=0.9142785668373108),
  SearchResult(id='2hop__661_26316_14', title='French Second Republic', preview="French Republic République française 1848 -- 1852 Flag Great Seal Motto Liberté, Égalité, Fraternité ``Liberty, Equal

In [39]:
doc_id = ["2hop__152921_19219_0", "2hop__782306_3089_3"]

results = client.query(
    collection_name=config.milvus_collection_name,
    ids=doc_id,
    output_fields=["text", "title"],
)
if not results:
    # 如果查询结果为空，说明该ID在数据库中不存在
    raise HTTPException(
        status_code=404, detail=f"文档ID {doc_id} 在数据库中未找到。"
    )

results

data: ['{\'text\': "Paris (French pronunciation: \\u200b (paʁi) (listen)) is the capital and most populous city in France, with an administrative - limits area of 105 square kilometres (41 square miles) and an official population of 2,206,488 (2015). The city is a commune and department, and the heart of the 12,012 - square - kilometre (4,638 - square - mile) Île - de-France region (colloquially known as the \'Paris Region\'), whose 2016 population of 12,142,802 represented roughly 18 percent of the population of France. Since the 17th century, Paris has been one of Europe\'s major centres of finance, commerce, fashion, science, and the arts. The Paris Region had a GDP of €649.6 billion (US $763.4 billion) in 2014, accounting for 30.4 percent of the GDP of France. According to official estimates, in 2013 - 14 the Paris Region had the third - highest GDP in the world and the largest regional GDP in the EU.", \'title\': \'Paris\', \'id\': \'2hop__152921_19219_0\'}', '{\'text\': \'General

In [42]:
tmp = [ScrapeResult(id=doc["id"], title=doc["title"], full_text=doc["text"]) for doc in results]
tmp

[ScrapeResult(id='2hop__152921_19219_0', title='Paris', full_text="Paris (French pronunciation: \u200b (paʁi) (listen)) is the capital and most populous city in France, with an administrative - limits area of 105 square kilometres (41 square miles) and an official population of 2,206,488 (2015). The city is a commune and department, and the heart of the 12,012 - square - kilometre (4,638 - square - mile) Île - de-France region (colloquially known as the 'Paris Region'), whose 2016 population of 12,142,802 represented roughly 18 percent of the population of France. Since the 17th century, Paris has been one of Europe's major centres of finance, commerce, fashion, science, and the arts. The Paris Region had a GDP of €649.6 billion (US $763.4 billion) in 2014, accounting for 30.4 percent of the GDP of France. According to official estimates, in 2013 - 14 the Paris Region had the third - highest GDP in the world and the largest regional GDP in the EU."),
 ScrapeResult(id='2hop__782306_3089

In [13]:
query = ["世界上最高的山峰是什么？", "谁是美国的第一任总统？"]
query_vector = embedding_model.encode_query(
        query, normalize_embeddings=True
    )

query_vector

array([[-0.01220703,  0.04760742, -0.00473022, ...,  0.00180054,
        -0.02770996,  0.00744629],
       [-0.0088501 ,  0.00982666,  0.00156403, ...,  0.02624512,
        -0.04174805,  0.0001545 ]], shape=(2, 1024), dtype=float32)

In [14]:
# 在Milvus中进行搜索
results = client.search(
    collection_name=MILVUS_COLLECTION_NAME,
    data=query_vector,
    anns_field="vector",  # 您在Milvus中定义的向量字段名
    search_params={"metric_type": "IP", "params": {}},
    limit=50,
    output_fields=["text", "title"],  # 输出字段包括文本和ID
)
results

data: [[{'id': '2hop__47005_159536_16', 'distance': 0.7016823291778564, 'entity': {'text': "Chimborazo, 6,267 metres (20,561 ft). Presumed highest from sixteenth century until the beginning of the 19th century. Not in the top 100 highest mountains when measured from sea level, however due to the earth's equatorial bulge this is the farthest point from the Earth's center. Nanda Devi, 7,816 metres (25,643 ft). Presumed highest in the world before Kangchenjunga was sighted in an era when Nepal was still closed to the outside world. Now known to be the 23rd highest mountain in the world. Dhaulagiri, 8,167 metres (26,795 ft). Presumed highest from 1808 until 1847. Now known to be the 7th highest mountain in the world. Kangchenjunga, 8,586 metres (28,169 ft). Presumed highest from 1847 until 1852. Now known to be the 3rd highest mountain in the world. Mount Everest, 8,848 metres (29,029 ft). Established as highest in 1852 and officially confirmed in 1856. K2, 8,611 metres (28,251 ft). Discov

In [15]:
len(results)

2

In [46]:
seen_texts = set()
text = [item["entity"]["text"] for item in results[0] if item["entity"]["title"] not in seen_texts and not seen_texts.add(item["entity"]["title"])]

text

["Chimborazo, 6,267 metres (20,561 ft). Presumed highest from sixteenth century until the beginning of the 19th century. Not in the top 100 highest mountains when measured from sea level, however due to the earth's equatorial bulge this is the farthest point from the Earth's center. Nanda Devi, 7,816 metres (25,643 ft). Presumed highest in the world before Kangchenjunga was sighted in an era when Nepal was still closed to the outside world. Now known to be the 23rd highest mountain in the world. Dhaulagiri, 8,167 metres (26,795 ft). Presumed highest from 1808 until 1847. Now known to be the 7th highest mountain in the world. Kangchenjunga, 8,586 metres (28,169 ft). Presumed highest from 1847 until 1852. Now known to be the 3rd highest mountain in the world. Mount Everest, 8,848 metres (29,029 ft). Established as highest in 1852 and officially confirmed in 1856. K2, 8,611 metres (28,251 ft). Discovered in 1856 before Mt. Everest was officially confirmed, K2's elevation became something 

In [50]:
def format_queries(query, instruction=None):
    prefix = '<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. \
        Note that the answer can only be "yes" or "no".<|im_end|>\n<|im_start|>user\n'
    if instruction is None:
        instruction = (
            "Given a web search query, retrieve relevant passages that answer the query"
        )
    return f"{prefix}<Instruct>: {instruction}\n<Query>: {query}\n"


def format_document(document):
    suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
    return f"<Document>: {document}{suffix}"


task = "Given a web search query, retrieve relevant passages that answer the query"

query = format_queries(query, task)
documents = [format_document(doc) for doc in text]

ranks = reranker_model.rank(
    query=query, documents=text, convert_to_tensor=True, return_documents=True
)

ranks

[{'corpus_id': 4,
  'score': tensor(0.6836, device='cuda:0', dtype=torch.bfloat16),
  'text': 'Tongshanjiabu is a mountain in the Himalayas. At tall, Tongshanjiabu is the 103rd tallest mountain in the world. It sits in the disputed border territory between Bhutan and China. Tongshanjiabu has never been officially climbed.'},
 {'corpus_id': 1,
  'score': tensor(0.6602, device='cuda:0', dtype=torch.bfloat16),
  'text': "Mount Everest, known in Nepali as Sagarmatha (सगरमाथा) and in Tibetan as Chomolungma (ཇོ ་ མོ ་ གླང ་ མ), is Earth's highest mountain above sea level, located in the Mahalangur Himal sub-range of the Himalayas. The international border between Nepal (Province No. 1) and China (Tibet Autonomous Region) runs across its summit point."},
 {'corpus_id': 6,
  'score': tensor(0.6523, device='cuda:0', dtype=torch.bfloat16),
  'text': 'The Matterhorn (German: Matterhorn, (ˈmatərˌhɔrn); Italian: Cervino, (ˈtʃerˈviːno); French: Le Cervin, (mɔ̃ sɛʁvɛ̃)) is a mountain of the Alps, str

In [48]:
ranks = reranker_model.rank(
    query=query, documents=documents, convert_to_tensor=True, return_documents=True
)
ranks

[{'corpus_id': 0,
  'score': tensor(0.9922, device='cuda:0', dtype=torch.bfloat16),
  'text': "<Document>: Chimborazo, 6,267 metres (20,561 ft). Presumed highest from sixteenth century until the beginning of the 19th century. Not in the top 100 highest mountains when measured from sea level, however due to the earth's equatorial bulge this is the farthest point from the Earth's center. Nanda Devi, 7,816 metres (25,643 ft). Presumed highest in the world before Kangchenjunga was sighted in an era when Nepal was still closed to the outside world. Now known to be the 23rd highest mountain in the world. Dhaulagiri, 8,167 metres (26,795 ft). Presumed highest from 1808 until 1847. Now known to be the 7th highest mountain in the world. Kangchenjunga, 8,586 metres (28,169 ft). Presumed highest from 1847 until 1852. Now known to be the 3rd highest mountain in the world. Mount Everest, 8,848 metres (29,029 ft). Established as highest in 1852 and officially confirmed in 1856. K2, 8,611 metres (28,

In [58]:
result = client.query(
        collection_name=MILVUS_COLLECTION_NAME,
        ids="2hop__47005_159536_16",
        output_fields=["text", "title"],
    )
result

data: ['{\'text\': "Chimborazo, 6,267 metres (20,561 ft). Presumed highest from sixteenth century until the beginning of the 19th century. Not in the top 100 highest mountains when measured from sea level, however due to the earth\'s equatorial bulge this is the farthest point from the Earth\'s center. Nanda Devi, 7,816 metres (25,643 ft). Presumed highest in the world before Kangchenjunga was sighted in an era when Nepal was still closed to the outside world. Now known to be the 23rd highest mountain in the world. Dhaulagiri, 8,167 metres (26,795 ft). Presumed highest from 1808 until 1847. Now known to be the 7th highest mountain in the world. Kangchenjunga, 8,586 metres (28,169 ft). Presumed highest from 1847 until 1852. Now known to be the 3rd highest mountain in the world. Mount Everest, 8,848 metres (29,029 ft). Established as highest in 1852 and officially confirmed in 1856. K2, 8,611 metres (28,251 ft). Discovered in 1856 before Mt. Everest was officially confirmed, K2\'s elevat

In [61]:
result[0]

{'text': "Chimborazo, 6,267 metres (20,561 ft). Presumed highest from sixteenth century until the beginning of the 19th century. Not in the top 100 highest mountains when measured from sea level, however due to the earth's equatorial bulge this is the farthest point from the Earth's center. Nanda Devi, 7,816 metres (25,643 ft). Presumed highest in the world before Kangchenjunga was sighted in an era when Nepal was still closed to the outside world. Now known to be the 23rd highest mountain in the world. Dhaulagiri, 8,167 metres (26,795 ft). Presumed highest from 1808 until 1847. Now known to be the 7th highest mountain in the world. Kangchenjunga, 8,586 metres (28,169 ft). Presumed highest from 1847 until 1852. Now known to be the 3rd highest mountain in the world. Mount Everest, 8,848 metres (29,029 ft). Established as highest in 1852 and officially confirmed in 1856. K2, 8,611 metres (28,251 ft). Discovered in 1856 before Mt. Everest was officially confirmed, K2's elevation became so