In [14]:
#coding:utf-8
import time
import numpy as np 
import faiss

""" 
1: 向量128维，构造200万个向量，服从正态分布
"""
# Dimensionality of each vector and number of database vectors to generate.
d = 128
n_data = 2000000
# Seed NumPy's global generator so the synthetic data is reproducible.
np.random.seed(0)

# Mean and standard deviation of the normal distribution that every vector
# component is drawn from.
mu = 3
sigma = 0.1

# Draw the whole (n_data, d) matrix in one vectorised call instead of
# appending 2,000,000 small arrays to a Python list and copying them again
# with np.array(): no interpreter-level loop and no intermediate list.
# The matrix is filled row by row from the same seeded global generator.
# Vectors are stored as float32, i.e. 32 bits (4 bytes) per component.
data = np.random.normal(mu, sigma, (n_data, d)).astype('float32')

In [15]:
"""
2: 构造10个查询向量
"""

# Number of query vectors. They are drawn from the same distribution as the
# database vectors (mu, sigma and d are defined in the data-generation cell).
n_query = 10

# A single vectorised draw replaces the append-in-a-loop pattern; the result
# is an (n_query, d) float32 matrix.
query = np.random.normal(mu, sigma, (n_query, d)).astype('float32')

In [14]:
"""
3: 构建索引
"""
import faiss  # also imported in the first cell; kept so this cell can be re-run alone

# Number of PQ sub-vectors: every 128-d vector is cut into m = 8 pieces of
# 16 dimensions each, so d must be divisible by m.
m = 8
assert d % m == 0, "d must be divisible by m"

# Number of inverted lists (coarse clusters) the IVF part partitions the
# database into.
nlist = 100

# Number of nearest neighbours returned per query (top-10).
k = 10

# Bits per PQ code: each 16-d sub-vector is stored as one 8-bit code.
# The original note says 4 bits can be used for a higher compression ratio;
# check that the installed faiss version supports that before relying on it.
nbits = 8

# perf_counter() is the monotonic, high-resolution clock meant for measuring
# elapsed time. Note that the interval below covers index construction,
# training AND adding the vectors, although the message only says "train".
start = time.perf_counter()

# Coarse quantizer used to assign vectors to the nlist inverted lists.
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits)

index.train(data)
index.add(data)

# Author's understanding: the query is compared with the 100 cluster centres,
# the 10 closest clusters are kept, and the top-10 neighbours are searched
# only inside those 10 clusters (similar in spirit to beam search). With
# nprobe = 100 every cluster is visited, so the cost is the same as having
# no clustering at all.
index.nprobe = 10
end = time.perf_counter()
print("Time usage for train: {} seconds.".format(round((end - start),4)))

Time usage for train: 11.4703 seconds.


In [15]:
"""
4: 近似最近邻搜索
"""

""" 查询验证 """
# Sanity check: query with the first 10 database vectors themselves. In the
# recorded output each of them comes back as its own nearest neighbour
# (ids 0..9 in the first column).
# perf_counter() is used because the search lasts only a few milliseconds,
# which a wall-clock timer may not resolve reliably.
start = time.perf_counter()
dis, ind = index.search(data[:10], k)
end = time.perf_counter()

print(ind)

print("\nTime usage for search: {} seconds.\n".format(round((end - start),4)))

[[      0 1353351  768649  401902  946967  760506 1917035 1362211  107858
    61533]
 [      1 1500400  202228 1810408  180601  895364 1005966  579157  155041
  1033148]
 [      2  909495  454642 1206165  662479 1766244 1227949  823773  176253
   607382]
 [      3 1018879 1449759 1137078 1136978  434316  702742  372995  688347
  1791706]
 [      4  779645   36929  829729 1728439 1578693  707809   79552 1671465
  1666195]
 [      5  824342 1880572 1475852  813766 1347376 1221624  803533 1217967
  1350254]
 [      6  201762   48639 1050127 1418157 1892496  661704  909349  941782
    88028]
 [      7  359775   47062  847471 1653170 1915279  739487  931792 1146174
   102686]
 [      8  754021  899503  118132 1832729 1734031  931788 1691304 1415036
   814538]
 [      9  907700 1995821  762442 1589836  989390  986111 1279824 1443448
  1891330]]

Time usage for search: 0.0045 seconds.



In [20]:
""" 真实查找 """

# Search with the independently generated query vectors (not part of the
# database) and time it with the monotonic high-resolution clock.
start = time.perf_counter()
dis, ind = index.search(query, k)
end = time.perf_counter()

print(ind)

print("\nTime usage for search: {} seconds.\n".format(round((end - start),4)))

[[1689509  278454  716957   78299  292142  244317 1924412 1803720  783697
  1395005]
 [1874442  983923  523909  241569 1409387 1865962  860643  776683   13786
  1318036]
 [1609465  362099  792332    6226 1672637 1906965  574172 1975183  802235
  1140470]
 [ 751006  274723  383712  366339 1243287  977750 1164491  647238  991833
   699911]
 [1503916  362288  990821 1767503    7044  373711  979861 1838789 1904345
   176972]
 [  80562 1347368  821227    3651  262921  800985  623091 1828868 1319990
    51973]
 [ 554628 1719245 1837069  354601 1801095 1593832  741443  495958 1289137
  1788412]
 [1714384 1721760  964018 1440900 1278580 1344095 1487214  523925 1044892
  1316870]
 [ 434704  419990 1690755 1967662  798329 1767825  997462 1227532  962835
  1732273]
 [ 511363  122297   14457  298357 1517231  693771 1647417 1449401 1839261
   937778]]

Time usage for search: 0.0103 seconds.



In [17]:
"""
5: 保存索引和加载索引
"""
# Round-trip the trained index through disk: serialise it, then rebind
# `index` to the copy read back from the same file.
ivfpq_index_path = "index_IVFPQ.index"
faiss.write_index(index, ivfpq_index_path)
index = faiss.read_index(ivfpq_index_path)

In [27]:
"""
6: 自定义向量id
"""

""" id必须为int类型 """
# Custom ids start at id_offset and there is exactly one per database vector.
# The count is taken from `data` (instead of a second hard-coded 4000000) so
# ids and vectors cannot get out of step if n_data changes.
# faiss expects 64-bit integer ids; the dtype is requested explicitly because
# np.arange's default integer type is platform dependent.
id_offset = 2000000
ids = np.arange(id_offset, id_offset + len(data), dtype=np.int64)

# A fresh IVFPQ index with the same parameters as the earlier `index`.
quantizer = faiss.IndexFlatL2(d)
index_ = faiss.IndexIVFPQ(quantizer, d, nlist, m, 8)

index_.train(data)

# Store every vector under its custom id instead of the default 0..n-1.
index_.add_with_ids(data, ids)

# NOTE(review): nprobe is never set on index_, so it keeps the library
# default, unlike `index` which uses nprobe = 10. That is presumably why the
# ids printed here are not simply the earlier results shifted by id_offset.
dis, ind = index_.search(query, k)

print(ind)

[[2716957 3803720 2246093 2621352 3062877 3778670 2595236 3351922 3617902
  3436506]
 [2241569 3849570 3948718 2072639 3158534 2310042 2010213 2903404 3747512
  3012725]
 [2362099 3750539 2724013 2930538 3800018 2276698 3466860 2969315 3056714
  3799970]
 [2751006 2366339 2977750 3517271 3856838 2117451 3983297 3018759 3057012
  2233511]
 [3503916 2362288 2176972 3254191 3536227 3739178 2797112 3866389 3164465
  2646507]
 [3347368 2003651 2623091 2051973 3245546 3375281 2963343 2518915 2760463
  2551578]
 [3593832 3024460 3186823 3535837 3926399 3208263 3176438 2474239 2236536
  2136431]
 [3714384 3440900 3344095 3044892 2044594 3669190 2862588 3027700 2101044
  2161926]
 [2434704 2419990 3673720 2587669 3810408 2192304 2402803 3571180 2330133
  2351613]
 [2511363 2014457 2937778 2748755 2579653 3229210 2684650 2214652 2512157
  3691442]]


In [29]:
# Display the element type of the custom id array.
type(ids[0])

numpy.int64

In [18]:
"""
7：索引工厂
"""


# Build a composite index from a factory string instead of wiring the parts
# by hand. Per the faiss index-factory documentation (verify against the
# installed version): "OPQ16_64" is an OPQ pre-transform to 64 dimensions,
# "IMI2x8" an inverted multi-index coarse quantizer and "PQ16" a 16-byte
# product-quantizer encoding.
index_fac = faiss.index_factory(d,"OPQ16_64,IMI2x8,PQ16")

# The interval covers training AND adding the vectors, although the message
# only says "train". perf_counter() is the monotonic high-resolution clock.
start = time.perf_counter()
index_fac.train(data)
index_fac.add(data)
end = time.perf_counter()
print("\nTime usage for train: {} seconds.\n".format(round((end - start),4)))

# Sanity check with the first 10 database vectors (not the separate `query`
# set), asking for the same top-k as the other searches in this notebook.
start = time.perf_counter()
dis, ind = index_fac.search(data[:10], k)
end = time.perf_counter()

print(ind)
print("\nTime usage for search: {} seconds.\n".format(round((end - start),4)))


Time usage for train: 80.6588 seconds.

[[      0  939914 1160160  641444 1641268 1127837 1382182 1093474  899888
  1612053]
 [      1  268413  785710 1202730  979573  539834 1300504  197580 1712463
  1785438]
 [      2 1518155 1476130  571038  315617 1797226  254239 1448407 1659195
  1655313]
 [      3  943021 1004456 1104984  183340  211493  207600 1549732  883171
  1569879]
 [      4 1570087 1178419  806223 1701559  941105  349088 1968788 1931221
   292066]
 [      5  237800 1358549 1140185   55181  539609 1842772 1171596  340766
  1136310]
 [      6 1882569  701754 1714820 1437393  646467 1244156 1755316   16174
  1585437]
 [      7  492819 1883297  933850  633843  564510   61987 1694862 1975458
  1212472]
 [      8  867448  720355 1026777  283955 1125265  214737 1231826  641251
   643058]
 [      9 1135121 1051951 1766581 1847726 1911032  654494  575636 1125150
   187051]]

Time usage for search: 0.0004 seconds.



In [19]:
# Persist the factory-built index to disk; the file name mirrors the factory string.
faiss.write_index(index_fac, "index_OPQ16_64,IMI2x8,PQ16.index")

In [21]:
"""
8：精确查找作为baseline，进行对比
"""

# Exact (non-approximate) L2 index used as the baseline for comparison.
index_exact = faiss.IndexFlatL2(d)

# A flat index needs no training step; vectors are added directly.
index_exact.add(data)

# Same sanity queries and top-k as the approximate indexes so that results
# and timings are comparable. perf_counter() is the monotonic clock meant
# for measuring elapsed time.
start = time.perf_counter()
dis, ind = index_exact.search(data[:10], k)
end = time.perf_counter()

print(ind)
print("\nTime usage for search: {} seconds.\n".format(round((end - start),4)))

[[      0  342266  423764  346234  681189  916533  975052  788138 1001087
  1802084]
 [      1  161229  991976 1304660 1309207  961456  884171  344711 1032968
   830166]
 [      2 1978832  122574  128882  141577 1480724  458154 1578044  912623
   334919]
 [      3 1521806 1817963  430657  179778 1450209 1756022 1037252 1041013
   300457]
 [      4 1666195 1995493  431622  622919 1842401  711426  245512  151672
  1351962]
 [      5  658751  406518  476064  621529   67393  670780 1647785  873970
   749287]
 [      6   43327 1882020 1628440 1143866  512111  708542 1036341 1765963
   966288]
 [      7  591397 1899843  534510  664256 1688601 1019285 1747174 1161384
   464800]
 [      8  338530 1797500 1095226  336240 1129628  117584   17118  771715
  1301225]
 [      9 1344145  989390 1393315  905645  552578  421387 1272224  129364
  1153549]]

Time usage for search: 0.5881 seconds.

