[Billion-scale semantic similarity search with FAISS + SBERT](https://www.leiphone.com/news/202011/jRU5fk4FpyYZzLMM.html)

### Introduction

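Encode every sentence in the dataset with SBERT (Sentence-BERT) and index the resulting embeddings with FAISS.

When a new sentence arrives, encode it with the same model, then use FAISS to look up the most similar sentences in the index.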
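With a flat index such as the `IndexFlatIP` used in this notebook, that lookup is exhaustive: the query vector is scored against every stored vector by inner product and the highest scores win. A minimal NumPy sketch of what that exact search computes, assuming tiny made-up 3-dimensional vectors and hypothetical headlines in place of real 768-dimensional SBERT embeddings:

```python
import numpy as np

# Hypothetical toy corpus: 4 "sentence embeddings" in a 3-dimensional space.
corpus_texts = ["fire in sydney", "bushfire warning", "stock market falls", "shares drop"]
corpus_vecs = np.array([
    [0.9, 0.1, 0.0],
    [0.8, 0.2, 0.1],
    [0.0, 0.1, 0.9],
    [0.1, 0.0, 0.8],
], dtype=np.float32)


def exact_top_k(query_vec, vecs, k):
    """Score every stored vector by inner product and return the k best (scores, ids)."""
    scores = vecs @ query_vec          # one dot product per stored sentence
    ids = np.argsort(-scores)[:k]      # positions of the k largest scores, best first
    return scores[ids], ids


# Pretend this is the encoding of a new query sentence.
query_vec = np.array([0.85, 0.15, 0.05], dtype=np.float32)
scores, ids = exact_top_k(query_vec, corpus_vecs, k=2)
print([corpus_texts[i] for i in ids])
# ['fire in sydney', 'bushfire warning']
```

FAISS performs the same computation in optimized native code, which is what makes a brute-force scan over a million vectors practical.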

### Install and load packages

```python
!pip install faiss-cpu
!pip install -U sentence-transformers
```

```
Requirement already up-to-date: sentence-transformers in /usr/local/lib/python3.6/dist-packages (0.3.9)
```


```python
import numpy as np
import torch
import os
import pandas as pd
import faiss
import time
from sentence_transformers import SentenceTransformer
```

```python
!wget https://github.com/franciscadias/data/raw/master/abcnews-date-text.csv
```




### Load and inspect the data

```python
df = pd.read_csv("abcnews-date-text.csv")
data = df.headline_text.to_list()
```

```python
data[:10]
```

```
['aba decides against community broadcasting licence',
 'act fire witnesses must be aware of defamation',
 'a g calls for infrastructure protection summit',
 'air nz staff in aust strike for pay rise',
 'air nz strike to affect australian travellers',
 'ambitious olsson wins triple jump',
 'antic delighted with record breaking barca',
 'aussie qualifier stosur wastes four memphis match',
 'aust addresses un security council over iraq',
 'australia is locked into war timetable opp']
```

### Load a pretrained model and encode the data

```python
model = SentenceTransformer('distilbert-base-nli-mean-tokens', device='cuda')
```

```python
# Encode every headline into one (n_sentences, 768) float32 array
encoded_data = model.encode(data)
```

```python
encoded_data.shape
```

```
(1082168, 768)
```
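`IndexFlatIP` stores every vector uncompressed as float32, so the memory it needs follows directly from this shape. A back-of-the-envelope calculation, where 4 bytes per float32 is the only assumption:

```python
# Memory needed to hold the raw embeddings, which a flat FAISS index stores uncompressed.
n_vectors = 1_082_168     # rows, from encoded_data.shape above
dim = 768                 # columns: embedding size of distilbert-base-nli-mean-tokens
bytes_per_float32 = 4

total_bytes = n_vectors * dim * bytes_per_float32
print(f"{total_bytes:,} bytes = {total_bytes / 1e9:.2f} GB = {total_bytes / 2**30:.2f} GiB")
# 3,324,420,096 bytes = 3.32 GB = 3.10 GiB
```

Scaling the same arithmetic to one billion 768-dimensional vectors gives about 3.07 TB, which is why billion-scale FAISS deployments usually rely on compressed indexes (for example the IVF and PQ families) rather than a flat one.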

### Index the dataset

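```python
# Exact (brute-force) inner-product index, wrapped in IndexIDMap so each vector carries an explicit ID
index = faiss.IndexIDMap(faiss.IndexFlatIP(encoded_data.shape[1]))
# FAISS expects 64-bit integer IDs; here each ID is the headline's position in `data`
index.add_with_ids(encoded_data, np.arange(len(data), dtype='int64'))
```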
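`IndexFlatIP` ranks by raw inner product. Unless the embeddings are unit-length, that is not the same as cosine similarity, because longer vectors receive larger scores. Normalising both the stored vectors and the query to unit L2 norm makes the two coincide. A small NumPy check of that identity, using random vectors as stand-ins for `encoded_data` and the query embedding:

```python
import numpy as np


def l2_normalize(x):
    """Scale each row to unit length (a NumPy stand-in for faiss.normalize_L2, which works in place)."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    return x / np.maximum(norms, 1e-12)   # guard against division by zero


rng = np.random.default_rng(0)
corpus = rng.normal(size=(5, 8)).astype(np.float32)   # stand-in for encoded_data
query = rng.normal(size=(1, 8)).astype(np.float32)    # stand-in for model.encode([query])

# Cosine similarity computed from its definition: dot product divided by both norms.
cosine = (corpus @ query.T).ravel() / (np.linalg.norm(corpus, axis=1) * np.linalg.norm(query))

# Inner product after normalising both sides, which is what IndexFlatIP would score.
ip_normalized = (l2_normalize(corpus) @ l2_normalize(query).T).ravel()

print(np.allclose(cosine, ip_normalized, atol=1e-5))
# True
```

Applied to this notebook, the assumption is that one would call `faiss.normalize_L2(encoded_data)` before `add_with_ids` and `faiss.normalize_L2(query_vector)` before `search` (it normalises a float32 array in place); the rankings, and therefore the printed results, could then differ from the ones shown here.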

#### Serialize the index (save it to disk)

```python
faiss.write_index(index, 'abc_news')
```

#### Deserialize the index (read it back from disk)

```python
index = faiss.read_index('abc_news')
```

### Run a semantic similarity search

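```python
def search(query):
    t = time.time()
    query_vector = model.encode([query])   # (1, 768) float32 array
    k = 5
    # index.search returns (distances, ids), each of shape (n_queries, k)
    top_k = index.search(query_vector, k)
    print('totaltime: {}'.format(time.time() - t))  # includes query encoding time
    return [data[_id] for _id in top_k[1].tolist()[0]]
```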

```python
results = search("test news title")
```

```
totaltime: 1.1659197807312012
```


```python
print(results)
```

```
['test article', 'test preview', 'news quiz', 'news quiz', 'news quiz']
```
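The result list contains `'news quiz'` three times: the IDs returned by a flat index are distinct, so the dataset itself must hold that headline at least three times, each copy as its own vector. One way to get k distinct headlines is to request more neighbours than needed and drop repeats while keeping FAISS's best-first order. A sketch with a hypothetical helper `unique_top_k`; the headline list in the usage example is illustrative, not real search output:

```python
def unique_top_k(ids, texts, k):
    """Return up to k distinct texts, preserving best-first order.

    `ids` is one row of the ID matrix returned by index.search; FAISS fills
    missing results with -1, so those entries are skipped.
    """
    seen = set()
    results = []
    for _id in ids:
        if _id < 0:
            continue
        text = texts[_id]
        if text in seen:
            continue
        seen.add(text)
        results.append(text)
        if len(results) == k:
            break
    return results


# Illustrative data standing in for `data` and one row of `index.search(...)[1]`.
headlines = ["test article", "test preview", "news quiz", "news quiz", "news quiz", "news in brief"]
ids_row = [0, 1, 2, 3, 4, 5, -1]
print(unique_top_k(ids_row, headlines, k=4))
# ['test article', 'test preview', 'news quiz', 'news in brief']
```

In the notebook this would mean calling `index.search(query_vector, 20)` (20 is an arbitrary over-fetch size) and passing `top_k[1][0]` and `data` to the helper with `k=5`.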


### GPU FAISS

To search on a GPU, replace the CPU build of FAISS with the GPU build.

```python
!pip uninstall -y faiss-cpu
!pip install faiss-gpu
```

```
Successfully uninstalled faiss-cpu-1.6.4.post2
Successfully installed faiss-gpu-1.6.4.post2
```

After installing, restart the runtime so that `import faiss` loads the GPU build, then re-run the cells that import the packages, load the data, load the model, and read the index from disk with `faiss.read_index('abc_news')`. The headlines do not need to be encoded again.


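```python
res = faiss.StandardGpuResources()                  # GPU resources (temporary memory, CUDA streams) used by FAISS
gpu_index = faiss.index_cpu_to_gpu(res, 0, index)   # copy the CPU index to GPU device 0
```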

```python
def search_gpu(query):
    t = time.time()
    query_vector = model.encode([query])
    k = 5
    top_k = gpu_index.search(query_vector, k)  # the only difference from search()
    print('totaltime: {}'.format(time.time() - t))
    return [data[_id] for _id in top_k[1].tolist()[0]]
```

```python
results = search_gpu("test news title")
```

```
totaltime: 0.056054115295410156
```


```python
print(results)
```

```
['test article', 'test preview', 'news quiz', 'news quiz', 'news quiz']
```
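The two timings (about 1.17 s on CPU versus 0.056 s on GPU, roughly a 20× difference) each come from a single call and include encoding the query with the model, so they measure more than the index alone, and a single call may also include one-off warm-up cost. A sketch of a hypothetical helper `median_runtime` that discards warm-up calls and reports the median of several timed runs:

```python
import statistics
import time


def median_runtime(fn, *args, warmup=1, repeats=5):
    """Run fn(*args) `warmup` times untimed, then return the median of `repeats` timed runs in seconds."""
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Stand-in workload so the sketch runs anywhere.
print(median_runtime(sorted, list(range(100_000, 0, -1))))

# In the notebook, timing only the index search on a pre-encoded query
# (assumes `model`, `index` and `gpu_index` from the cells above):
# query_vector = model.encode(["test news title"])
# cpu_seconds = median_runtime(index.search, query_vector, 5)
# gpu_seconds = median_runtime(gpu_index.search, query_vector, 5)
```

Taking the median rather than the mean keeps a single slow outlier (for example a garbage-collection pause) from skewing the comparison.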
