In [8]:
import faiss
faiss.__version__

'1.8.0'

In [9]:
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import numpy as np
import torch
import json
import itertools

In [10]:
class RecursiveNamespace:
    def __init__(self):
        self.__dict__['_attributes'] = {}

    def __getattr__(self, name):
        if name not in self._attributes:
            self._attributes[name] = RecursiveNamespace()
        return self._attributes[name]

    def __setattr__(self, name, value):
        if isinstance(value, RecursiveNamespace):
            self._attributes[name] = value
        else:
            current = self
            *parts, last = name.split('.')
            for part in parts:
                current = getattr(current, part)
            current._attributes[last] = value

    def __repr__(self):
        return repr(self._attributes)
    
    @staticmethod
    def from_dict(d):
        ns = RecursiveNamespace()
        for k, v in d.items():
            if isinstance(v, dict):
                ns._attributes[k] = RecursiveNamespace.from_dict(v)
            else:
                ns._attributes[k] = v
        return ns

def read_fvecs(filename):
    with open(filename, 'rb') as f:
        vecs = []
        while True:
            data = f.read(4)
            if len(data) < 4:
                break
            d = int.from_bytes(data, 'little')
            vec = np.frombuffer(f.read(d * 4), dtype=np.float32)
            vecs.append(vec)
        return np.array(vecs)

def load_config(filePath):
    '''
    load configuration json to a RecursiveNamespace object
    '''
    with open(filePath, 'r') as f:
        return RecursiveNamespace.from_dict(json.loads(f.read()))

In [11]:
import faiss
import numpy as np

# 初始化数据
d = 128  # 向量维度
nt = 10000  # 训练集大小
nb = 10000  # 数据库大小
nq = 100  # 查询向量数量
np.random.seed(1234)
# db_vectors = torch.zeros(1, 32, 114438, 128)
# for j in range(1):
#     db_vectors[0, j] = torch.tensor(read_fvecs(f'../llama_key/PTB/key_0_{j}.fvecs'))
# db_vectors.shape
db_vectors = torch.tensor(read_fvecs('../llama_key/PTB/key_0_0.fvecs'))


In [12]:
db_vectors.shape

torch.Size([114438, 128])

In [13]:
d = db_vectors.shape[-1]
nb = db_vectors.shape[-2]
nt = nb
nq = 100

# db_vectors = np.random.random((nb, d)).astype('float32')
# db_vectors = np.random.normal(1.0, 1.6, (nb, d))
# xtrain = np.random.random((nt, d)).astype('float32')
# db_vectors = np.random.normal(1.0, 2.6, (nb, d))
xtrain = db_vectors
query_vectors = np.random.random((nq, d)).astype('float32')

# FAISS资源和索引配置
# res = faiss.StandardGpuResources()  # 使用标准GPU资源
index_flat = faiss.IndexFlatIP(d)  # 使用内积作为距离度量

# 创建GpuIndexIVFPQ索引
nlist = 1  # 聚类中心的数量
m = 4  # 子向量数量
nbits = 12
k = 200  # 要返回的最近邻数量
quantizer = index_flat  # 使用IndexFlatIP作为量化器
# index = faiss.index_cpu_to_gpu(res, 0, faiss.IndexIVFPQ(quantizer, d, nlist, m, 8))
# index = faiss.index_cpu_to_gpu(res, 0, faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT))
# index = faiss.IndexIVFPQ(quantizer, d, nlist, m, nbits, faiss.METRIC_INNER_PRODUCT)
index = faiss.IndexPQ(d, m, nbits, faiss.METRIC_INNER_PRODUCT)
index.train(xtrain)

# index_inner = faiss.index_factory(d, f'PQ{m}x{nbits}', faiss.METRIC_INNER_PRODUCT)
# index_inner.train(xtrain)
# index = faiss.index_factory(d, f'IVF1,PQ{m}x{nbits}', faiss.METRIC_INNER_PRODUCT)
# index.quantizer.add(np.zeros((1, d)))
# index.pq = index_inner.pq
# index.is_trained = True

index.add(db_vectors)

# 检索
D, I = index.search(query_vectors, k)  # 搜索最近的k个向量
# recons_x = index.reconstruct()
# print(I[0])
# print(D[0])




In [17]:
k=200
trace1 = go.Scatter(x=np.arange(k), y=D[0][:], name='PQ')
trace2 = go.Scatter(x=np.arange(k), y=np.dot(query_vectors[0], db_vectors[I[0]].T), name='oriIP')
fig = go.Figure([trace1, trace2])
fig.show()


In [18]:
recons_x = index.reconstruct_n(0, nb)

In [19]:
trace1 = go.Scatter(x=np.arange(k), y=D[1][:], name='PQ')
trace2 = go.Scatter(x=np.arange(k), y=np.dot(query_vectors[1], db_vectors[I[1]].T), name='oriIP')
trace3 = go.Scatter(x=np.arange(k), y=np.dot(query_vectors[1], recons_x[I[1]].T), name='reconsIP')
fig = go.Figure([trace1, trace2, trace3])
fig.show()

In [30]:
# projection = np.random.random((d, d))*8
projection = np.random.normal(20, 8, (d, d))
projection += np.random.normal(10, 0.1, (d, d))
value = np.matmul(db_vectors, projection)
value_restored = np.matmul(recons_x, projection)
value.shape

torch.Size([114438, 128])

In [31]:
s = query_vectors[1]
k = 200
trace2 = go.Scatter(x=np.arange(k), y=np.dot(s, value[:k].T), name='oriIP')
trace3 = go.Scatter(x=np.arange(k), y=np.dot(s, value_restored[:k].T), name='reconsIP')
fig = go.Figure([trace2, trace3])
fig.show()

## 真实transition matrix

In [24]:
cache, pq, t = torch.load('weired.pt')
t.shape

torch.Size([4096, 4096])

In [25]:
db_vectors.shape

torch.Size([114438, 128])

In [None]:

projection = t.detach().numpy()
value = np.matmul(db_vectors, projection)
value_restored = np.matmul(recons_x, projection)
value.shape