In [1]:
from pymilvus import MilvusClient
from pymilvus import connections, MilvusClient, DataType
from pymilvus import AnnSearchRequest

# 云端向量数据库
CLUSTER_ENDPOINT = "https://in03-9e9fdbef8863964.serverless.ali-cn-hangzhou.cloud.zilliz.com.cn" 
TOKEN = "0b26d6c6392802bee965421c29437b0777d93c49f4871a18586ab47f1f3e98b9e805a65e8aa0a49b12ce816684cd94c8d9cd4344"

client = MilvusClient(
    uri=CLUSTER_ENDPOINT,
    token=TOKEN 
)

In [2]:
client.list_collections()

['vector_demo', 'rag_demo']

In [3]:
client.drop_collection(collection_name="vector_demo")

## RAG案例

In [4]:
import pandas as pd
from tqdm import tqdm

local_file_path = './dataset/test.txt'

# 使用 pandas 读取本地文件
news = pd.read_csv(local_file_path, sep='\t', header=None)[0].drop_duplicates().values

# 打印前几行，查看数据是否正确读取
print(news[:10])  # 只打印前 10 个新闻标题



In [5]:
from sentence_transformers import SentenceTransformer
local_model_path = "./bge/bge-small-en-v1.5"

# # 在线加载模型
# from sentence_transformers import SentenceTransformer
# model = SentenceTransformer("BAAI/bge-small-zh-v1.5")

# 本地加载模型
model = SentenceTransformer(local_model_path)

In [6]:
client.drop_collection(collection_name="vector_demo")

schema = MilvusClient.create_schema(
    auto_id=True,
    enable_dynamic_field=True,
)
schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True)
schema.add_field(field_name="chunk_content", datatype=DataType.VARCHAR, max_length=1024)
schema.add_field(field_name="chunk_sparse_embedding", datatype=DataType.SPARSE_FLOAT_VECTOR)
schema.add_field(field_name="chunk_embedding", datatype=DataType.FLOAT_VECTOR, dim=512)

index_params = MilvusClient.prepare_index_params()
index_params.add_index(
    field_name="chunk_embedding",
    metric_type="COSINE",
    index_type="AUTOINDEX",
    index_name="vector_index"
)
index_params.add_index(
    field_name="chunk_sparse_embedding",
    index_name="sparse_inverted_index",
    index_type="SPARSE_INVERTED_INDEX",
    metric_type="IP",
    params={"drop_ratio_build": 0.2},
)


client.create_collection(
    collection_name="vector_demo", 
    schema=schema, 
    index_params=index_params
)

In [7]:
from pymilvus.model.sparse.bm25.tokenizers import build_default_analyzer
from pymilvus.model.sparse import BM25EmbeddingFunction
 
analyzer = build_default_analyzer(language="en") 
bm25_ef = BM25EmbeddingFunction(analyzer)
bm25_ef.fit(news)

In [8]:
import numpy as np

# 截取或者填充到 512 维
def ensure_vector_dimension(vector, target_dim=512):
    if len(vector) > target_dim:
        return vector[:target_dim]
    elif len(vector) < target_dim:
        return np.pad(vector, (0, target_dim - len(vector)), 'constant')
    return vector


for title in tqdm(news):

    encode = model.encode(title)
    encode = ensure_vector_dimension(encode, 512)

    result = bm25_ef.encode_queries([title]).reshape(1, -1)
    
    res = client.insert(
        collection_name="vector_demo",
        data=[
            {
                "chunk_content": title,
                "chunk_sparse_embedding": {index: value for index, value in zip(result.indices, result.data)},
                "chunk_embedding": encode
            }
        ]
    )

100%|██████████| 4/4 [00:01<00:00,  2.48it/s]


In [12]:
from pymilvus import (
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection, AnnSearchRequest, RRFRanker, connections,
)
from pymilvus import WeightedRanker

query = "public void testGetInstance ( ) <<<<<<< \nthrows IllegalStateException , NamingException \n=======\nthrows Exception \n>>>>>>>" 
result = bm25_ef.encode_queries([query]).reshape(1, -1)

sparse_search_params = {"metric_type": "IP"}
sparse_req = AnnSearchRequest([{index: float(value) for index, value in zip(result.indices, result.data)}],
                              "chunk_sparse_embedding", sparse_search_params, limit=2)

dense_search_params = {"metric_type": "COSINE"}
dense_req = AnnSearchRequest([ensure_vector_dimension(list(model.encode(query)), 512)],
                             "chunk_embedding", dense_search_params, limit=2)

reqs = [sparse_req, dense_req]

rerank = WeightedRanker(0.2, 0.8)  
result = client.hybrid_search("vector_demo", reqs, ranker=rerank, limit=3, output_fields=["chunk_content"])

In [13]:
related_news = "\n".join([x["entity"]["chunk_content"] for x in result[0]])
prompt = f"""查询：
{query}

检索结果：
{related_news}
"""
print(prompt)

查询：
public void testGetInstance ( ) <<<<<<< 
throws IllegalStateException , NamingException 
throws Exception 
>>>>>>>

检索结果：

