In [1]:
from pymilvus import MilvusClient
from pymilvus import connections, MilvusClient, DataType
from pymilvus import AnnSearchRequest

# 云端向量数据库
CLUSTER_ENDPOINT = "https://in03-9e9fdbef8863964.serverless.ali-cn-hangzhou.cloud.zilliz.com.cn" 
TOKEN = "0b26d6c6392802bee965421c29437b0777d93c49f4871a18586ab47f1f3e98b9e805a65e8aa0a49b12ce816684cd94c8d9cd4344"

client = MilvusClient(
    uri=CLUSTER_ENDPOINT,
    token=TOKEN 
)

In [2]:
client.list_collections()

['vector_demo', 'rag_demo']

In [3]:
client.drop_collection(collection_name="rag_demo")

## RAG案例

In [4]:
import pandas as pd
from tqdm import tqdm
news = pd.read_csv("http://mirror.coggle.club/news-title.txt", sep='\t', header=None)[0].drop_duplicates().values

In [5]:
from sentence_transformers import SentenceTransformer
local_model_path = "./bge/bge-small-zh-v1.5"

# # 在线加载模型
# from sentence_transformers import SentenceTransformer
# model = SentenceTransformer("BAAI/bge-small-zh-v1.5")

# 从本地加载模型
model = SentenceTransformer(local_model_path)

In [6]:
client.drop_collection(collection_name="rag_demo")

schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=True,
)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="chunk_content", datatype=DataType.VARCHAR, max_length=1024)
schema.add_field(field_name="chunk_sparse_embedding", datatype=DataType.SPARSE_FLOAT_VECTOR)
schema.add_field(field_name="chunk_embedding", datatype=DataType.FLOAT_VECTOR, dim=512)

index_params = MilvusClient.prepare_index_params()
index_params.add_index(
    field_name="chunk_embedding",
    metric_type="COSINE",
    index_type="AUTOINDEX",
    index_name="vector_index"
)
index_params.add_index(
    field_name="chunk_sparse_embedding",
    index_name="sparse_inverted_index",
    index_type="SPARSE_INVERTED_INDEX",
    metric_type="IP",
    params={"drop_ratio_build": 0.2},
)


client.create_collection(
    collection_name="rag_demo", 
    schema=schema, 
    index_params=index_params
)

In [7]:
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.model.sparse import BM25EmbeddingFunction
 
analyzer = build_default_analyzer(language="zh") 
bm25_ef = BM25EmbeddingFunction(analyzer)
bm25_ef.fit(news)

Building prefix dict from the default dictionary ...
Loading model from cache C:\Users\DONGZH~1\AppData\Local\Temp\jieba.cache
Loading model cost 0.360 seconds.
Prefix dict has been built successfully.


In [8]:
for title in tqdm(news):
    encode = model.encode(title)
    result = bm25_ef.encode_queries([title]).reshape(1, -1)
    
    res = client.insert(
        collection_name="rag_demo",
        data=[
            {
                "chunk_content": title,
                "chunk_sparse_embedding": {index: value for index, value in zip(result.indices, result.data)},
                "chunk_embedding": encode
            }
        ]
    )

100%|██████████| 489/489 [00:47<00:00, 10.40it/s]


In [13]:
from pymilvus import (
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection, AnnSearchRequest, RRFRanker, connections,
)
from pymilvus import WeightedRanker

query = "最近南京相关的新闻"
result = bm25_ef.encode_queries([query]).reshape(1, -1)

sparse_search_params = {"metric_type": "IP"}
sparse_req = AnnSearchRequest([{index: float(value) for index, value in zip(result.indices, result.data)}],
                              "chunk_sparse_embedding", sparse_search_params, limit=10)

dense_search_params = {"metric_type": "COSINE"}
dense_req = AnnSearchRequest([list(model.encode(query))],
                             "chunk_embedding", dense_search_params, limit=10)

reqs = [sparse_req, dense_req]

rerank = WeightedRanker(0, 1)  
result = client.hybrid_search("rag_demo", reqs, ranker=rerank, limit=3, output_fields=["chunk_content"])

In [14]:
related_news = "\n".join([x["entity"]["chunk_content"] for x in result[0]])
prompt = f"""请对用户的提问进行回答：{query}

相关资料：{related_news}
"""
print(prompt)

请对用户的提问进行回答：最近南京相关的新闻

相关资料：几天后江苏将反超广东！江苏两市喜提地铁
网友晚间继续供稿 北京夜空被盛大绚丽烟火照亮
江苏野外闲置20辆公交车 有“骗补”可能

