In [1]:
import os
import faiss
import tqdm
import json
import numpy as np
from utils import extract_features_single, load_pca, create_model_and_tokenizer

from datasets import load_dataset

# Encoder used to embed the queries; its name also selects the directory the
# precomputed corpus features are read from.
# Alternative encoder: "sentence-transformers/all-MiniLM-L6-v2"
model_name = "facebook/contriever-msmarco"
model, tokenizer = create_model_and_tokenizer(model_name)
model = model.cuda()  # move the encoder onto the GPU

# Queries that will be augmented with retrieved passages.
query_data = load_dataset(
    'json',
    data_files='./lm-eval-train.jsonl',
    split='train',
)

# Corpus the retrieved passages are looked up in.
data = load_dataset(
    'json',
    data_files='/root/data/wiki_en.jsonl',
    split='train',
)




In [2]:
# Precomputed corpus embeddings, stored in a directory named after the encoder.
features_path = f'./features/{model_name}/wiki_en.npy'
vecs = np.load(features_path)

In [None]:
# Disabled experiment: keep only the corpus entries whose text has at least
# 100 whitespace-separated tokens, together with the feature row at the same
# position in `vecs`.
# new_data = []
# new_vecs = []
# for i, item in enumerate(data):
#     if item['text'].split().__len__() < 100:
#         continue
#     new_data.append(item)
#     new_vecs.append(vecs[i])

In [8]:
def build_index(
    xb: np.ndarray,
    d: int = None,
    nlist: int = 10000,
    pq_m: int = 16,
    pq_nbits: int = 8,
    hnsw_m: int = 32,
    nprobe: int = None,
):
    """Create and train an IVF-PQ index with an HNSW coarse quantizer.

    The returned index is trained but contains no vectors; they are added in
    a separate step.

    Args:
        xb: Training vectors. The size of the last axis is used as the
            embedding dimension when ``d`` is not given. (Presumably a 2-D
            float32 array, as faiss expects; this is not checked here.)
        d: Embedding dimension. Defaults to ``xb.shape[-1]``.
        nlist: Number of IVF cells (space partitions). A typical value is
            sqrt(N).
        pq_m: Number of PQ sub-vectors (typically 8, 16, 32, ...). ``d`` must
            be divisible by it.
        pq_nbits: Bits per sub-vector code. With 8, each sub-vector is
            encoded in one byte.
        hnsw_m: Number of neighbours for the HNSW quantizer (typically 32).
        nprobe: Number of IVF cells visited per query. ``None`` leaves the
            index's own default untouched.

    Returns:
        The trained ``faiss.IndexIVFPQ``.

    Raises:
        ValueError: If ``d`` is not divisible by ``pq_m``.
    """
    if d is None:
        d = xb.shape[-1]
    if d % pq_m != 0:
        raise ValueError(
            f"embedding dimension {d} must be divisible by pq_m={pq_m}"
        )

    # No GPU resources are allocated here: the index below is built and
    # trained through the plain (CPU) index classes, so a GPU handle would be
    # unused and would make the function fail on CPU-only faiss builds.
    # Other layouts that were tried:
    #   faiss.index_factory(d, "IVF100,PQ8")
    #   faiss.index_factory(d, "HNSW32")
    quantizer = faiss.IndexHNSWFlat(d, hnsw_m)
    index = faiss.IndexIVFPQ(quantizer, d, nlist, pq_m, pq_nbits)
    if nprobe is not None:
        index.nprobe = nprobe

    # Train on every vector. A cheaper alternative is to train on a sample:
    #   xb[np.random.choice(xb.shape[0], 1000_000, replace=False)]
    index.train(xb)
    return index

def add_to_index(xb, index):
    """Insert the vectors ``xb`` into ``index`` and hand the same index back.

    The index is modified in place through its ``add`` method; it is returned
    only so the call can be used on the right-hand side of an assignment.
    """
    index.add(xb)
    return index

In [9]:
def search(queries, k=100):
    """Retrieve up to ``k`` corpus entries for a query.

    Relies on the module-level ``model``, ``tokenizer``, ``index`` and
    ``data`` objects.

    Args:
        queries: Query input forwarded to ``extract_features_single``. Only
            the first row of the search result is used, so this is presumably
            a single query string (to be confirmed against the helper).
        k: Number of neighbours requested from the index.

    Returns:
        The matching ``data`` entries, in the order the index returned them.
        The list is shorter than ``k`` when the index finds fewer hits.
    """
    features = extract_features_single(queries, model, tokenizer)
    _, ids = index.search(features, k=k)
    # faiss fills result slots it could not populate with id -1. Indexing
    # ``data`` with -1 would silently return the last corpus entry, so those
    # slots are dropped instead of being turned into bogus passages.
    return [data[int(i)] for i in ids[0] if int(i) >= 0]

In [10]:
# Train the (still empty) index on every corpus embedding.
index = build_index(xb=vecs)

In [11]:
# Populate the index once. Guarding on `ntotal` keeps this cell idempotent:
# re-running it would otherwise insert every vector a second time and yield
# ids that no longer line up with `data`.
if index.ntotal == 0:
    index = add_to_index(vecs, index)

In [21]:
# search("Where is Zurich?", k=100)

In [None]:
# Write one JSON line per retrieved passage, annotated with its rank and the
# metadata of the query that retrieved it.
output_path = '/root/ft_data/lm_aug_wiki_hnsw_ivf_all.jsonl'
# Create the target directory up front so the run does not fail on open().
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
    for query in tqdm.tqdm(query_data):
        items = search(query['text'])
        for rank, item in enumerate(items):
            item['meta']['augment'] = {
                'rank': rank,
                'query_meta': query['meta'],
            }
            # The timestamp is dropped before serialisation. A record that
            # has none must not abort a multi-hour export, hence the default.
            item['meta'].pop('timestamp', None)
            f.write(json.dumps(item) + '\n')

  5%|██████▍                                                                                                                               | 3781/79195 [03:08<2:19:02,  9.04it/s]

In [None]:
# NOTE(review): calling exit() from a notebook cell presumably shuts the
# kernel down (dropping the model and the index) — confirm this is intended,
# since it also ends any "Run All" at this point.
exit()

In [None]:
1