### 导入包

In [None]:
import csv
import time
from itertools import islice

from pymilvus import utility
from pymilvus import connections
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection
from sentence_transformers import SentenceTransformer
# import gdown
# url = 'https://drive.google.com/uc?id=11ISS45aO2ubNCGaC3Lvd3D7NT8Y7MeO8'
# output = './movies.zip'
# gdown.download(url, output)
 
# import zipfile
 
# with zipfile.ZipFile("./movies.zip","r") as zip_ref:
#     zip_ref.extractall("./movies")
 

### 全局参数

这里集中列出了运行本示例时可能需要根据自己的环境修改的主要参数，每个参数旁边都有注释说明其含义。

In [3]:
# ---- Milvus setup ----
COLLECTION_NAME = 'movies_db'  # name of the collection created below
DIMENSION = 384                # length of each embedding vector; must match the model's output size
COUNT = 1000                   # number of movies (vectors) to embed and insert
MILVUS_HOST, MILVUS_PORT = 'localhost', '19530'  # address of the Milvus server

# ---- Embedding inference ----
BATCH_SIZE = 128  # rows encoded and inserted per batch

# ---- Search ----
TOP_K = 3  # number of results returned per query
 

### 连接到 Milvus 数据库

In [4]:

 

# Open a connection to the Milvus server configured in the global-parameters cell.
connections.connect(port=MILVUS_PORT, host=MILVUS_HOST)
 

### 删除之前创建的同名集合

In [5]:
# Remove a collection with the same name left over from an earlier run,
# so the schema defined in the next cell is created from scratch.
stale_collection_exists = utility.has_collection(COLLECTION_NAME)
if stale_collection_exists:
    utility.drop_collection(COLLECTION_NAME)
 

### 创建包含 ID、标题和情节文本嵌入的集合

In [6]:


# Primary key; auto_id=True means Milvus assigns ids, so inserts supply only title + embedding.
id_field = FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True)
# VARCHAR fields require a max_length; 200 is used for this example.
title_field = FieldSchema(name='title', dtype=DataType.VARCHAR, max_length=200)
# One embedding of the movie's plot text per row, DIMENSION floats long.
embedding_field = FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)

fields = [id_field, title_field, embedding_field]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)
 

### 为集合创建IVF_FLAT索引

In [7]:

# IVF_FLAT partitions the vectors into `nlist` clusters with k-means. A fixed
# nlist of 1536 is larger than the ~COUNT vectors this notebook inserts, and
# k-means cannot train more clusters than it has points. Size nlist from the
# row count instead, using the common 4 * sqrt(n) rule of thumb, capped at the
# previous value of 1536.
nlist = max(1, min(1536, int(4 * COUNT ** 0.5)))

index_params = {
    'metric_type': 'L2',
    'index_type': "IVF_FLAT",
    'params': {'nlist': nlist},
}
collection.create_index(field_name="embedding", index_params=index_params)
# Load the collection so it can be searched.
collection.load()
 

### 模型下载

In [10]:

from modelscope import snapshot_download

# Download all-MiniLM-L6-v2 from ModelScope, caching it under ./models.
model_dir = snapshot_download('bensonpeng/all-MiniLM-L6-v2', cache_dir='./models')

# Load from the directory snapshot_download actually returned rather than a
# hand-typed copy of the path, so this keeps working if the cache layout differs.
transformer = SentenceTransformer(model_dir)

# The collection's vector field was created with dim=DIMENSION. Fail fast here
# if the model produces a different embedding size, instead of at insert time.
_model_dim = transformer.get_sentence_embedding_dimension()
assert _model_dim == DIMENSION, (
    f'model embedding size {_model_dim} != collection DIMENSION {DIMENSION}'
)

2024-04-19 01:19:07,029 - modelscope - INFO - PyTorch version 2.2.2 Found.
2024-04-19 01:19:07,031 - modelscope - INFO - Loading ast index from C:\Users\mybcg\.cache\modelscope\ast_indexer
2024-04-19 01:19:07,130 - modelscope - INFO - Loading done! Current index file version is 1.13.3, with md5 61104cf01099cdfec0b7ca5a334bcfed and a total number of 972 components indexed
Downloading: 100%|██████████| 698/698 [00:00<?, ?B/s] 
Downloading: 100%|██████████| 86.7M/86.7M [00:04<00:00, 19.5MB/s]
Downloading: 100%|██████████| 17.4k/17.4k [00:00<00:00, 3.55MB/s]
Downloading: 100%|██████████| 695/695 [00:00<00:00, 684kB/s]
Downloading: 100%|██████████| 695k/695k [00:00<00:00, 5.24MB/s]
Downloading: 100%|██████████| 1.40k/1.40k [00:00<?, ?B/s]
Downloading: 100%|██████████| 226k/226k [00:00<00:00, 5.14MB/s]


### 导入数据到向量库

In [14]:

def csv_load(file, has_header=True):
    """Yield (title, plot) pairs from the movie-plots CSV.

    Column 1 is used as the title and column 7 as the plot. Rows that are too
    short to contain both columns, or whose title or plot is empty, are skipped.

    Args:
        file: Path to the CSV file (read as UTF-8).
        has_header: When True (the default), the first row is treated as column
            names and skipped, so it is not yielded as if it were a movie.

    Yields:
        (title, plot) tuples of strings.
    """
    with open(file, newline='', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter=',')
        if has_header:
            # Without this, the header row ('Title', 'Plot') would be inserted as a movie.
            next(reader, None)
        for row in reader:
            # Blank or truncated rows would otherwise raise IndexError below.
            if len(row) <= 7:
                continue
            title, plot = row[1], row[7]
            if '' in (title, plot):
                continue
            yield (title, plot)
 
def embed_insert(data):
    """Embed a batch of plots with the SentenceTransformer model and insert it.

    Args:
        data: Two parallel lists, ``[titles, plots]``. The plots (``data[1]``)
            are encoded into vectors; the titles (``data[0]``) are stored as-is.
    """
    titles = data[0]
    vectors = list(transformer.encode(data[1]))
    # Column order follows the schema's non-auto fields: title, then embedding.
    collection.insert([titles, vectors])
 

data_batch = [[], []]  # parallel lists: [titles, plots]
count = 0

# islice stops after exactly COUNT rows; the previous `count <= COUNT` check let
# COUNT + 1 rows through. max() keeps a zero/negative COUNT from raising ValueError.
for title, plot in islice(csv_load('./wiki_movie_plots_deduped.csv'), max(COUNT, 0)):
    data_batch[0].append(title)
    data_batch[1].append(plot)
    count += 1
    # Embed and insert each full batch of BATCH_SIZE rows.
    if len(data_batch[0]) == BATCH_SIZE:
        embed_insert(data_batch)
        data_batch = [[], []]

# Embed and insert the remainder (a final, partially filled batch).
if data_batch[0]:
    embed_insert(data_batch)

# Call a flush to index any unsealed segments.
collection.flush()
 

No sentence-transformers model found with name ./models/bensonpeng/all-MiniLM-L6-v2. Creating a new one with MEAN pooling.


### 使用 embedding 模型将查询文本转换成向量

In [29]:
# Free-text queries; the next cell finds the titles whose plot embeddings are closest to each.
search_terms = ['A movie about cars', 'A movie about monsters']


def embed_search(data):
    """Encode query strings with the SentenceTransformer model.

    Args:
        data: List of query strings.

    Returns:
        A list with one embedding per query, in input order.
    """
    return list(transformer.encode(data))


search_data = embed_search(search_terms)

### 执行向量检索并输出结果

In [28]:

 
# Use the same metric as the IVF_FLAT index. nprobe is the number of IVF clusters
# scanned per query: higher gives better recall at the cost of a slower search.
search_params = {'metric_type': 'L2', 'params': {'nprobe': 16}}

# perf_counter is a monotonic, high-resolution clock intended for measuring
# durations; time.time() follows the wall clock and can jump if it is adjusted.
start = time.perf_counter()
res = collection.search(
    data=search_data,         # one query embedding per search term
    anns_field="embedding",   # search across the plot embeddings
    param=search_params,
    limit=TOP_K,              # top_k results per query
    output_fields=['title'],  # include the title field in each hit
)
end = time.perf_counter()

# The loop assumes res holds one hit list per query, in search_terms order.
# 'Search Time' is the total for the whole batched search, repeated per query.
for hits_i, hits in enumerate(res):
    print('Title:', search_terms[hits_i])
    print('Search Time:', end - start)
    print('Results:')
    for hit in hits:
        print(hit.entity.get('title'), '----', hit.distance)
    print()
 

Title: A movie about cars
Search Time: 0.01253509521484375
Results:
From Leadville to Aspen: A Hold-Up in the Rockies ---- 39.53708267211914
Gentlemen of Nerve ---- 39.97119903564453
Hot Water ---- 41.052330017089844

Title: A movie about monsters
Search Time: 0.01253509521484375
Results:
The Suburbanite ---- 39.47476577758789
The Shriek of Araby ---- 42.2584228515625
The Cavalier ---- 42.49919128417969

