In [None]:
# CPU required for this notebooke: used v2-8 TPU

In [1]:
!pip install -q sentence-transformers faiss-cpu transformers datasets detoxify

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m470.2/470.2 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.3/31.3 MB[0m [31m51.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.8/494.8 kB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.5/143.5 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.8/194.8 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
from google.colab import drive
drive.mount('/content/drive')
rag_dir = '/content/drive/My Drive/kilt_rag_data'

Mounted at /content/drive


In [4]:
import json
import os
import numpy as np
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from datasets import load_dataset
from detoxify import Detoxify

In [5]:
embeddings_dir = os.path.join(rag_dir, "embeddings_v2")
os.makedirs(embeddings_dir, exist_ok=True)

faiss_dir = os.path.join(rag_dir, "faiss_index")
os.makedirs(faiss_dir, exist_ok=True)

In [None]:
#Load and convert all npz embeddings to 32f from 16f
def load_embedding_32f():
  all_embeddings = []

  npz_files = sorted([f for f in os.listdir(embeddings_dir) if f.endswith("_fp16.npz")])
  print(npz_files)
  for file in tqdm(npz_files, desc="Loading .npz chunks"):
    data = np.load(os.path.join(embeddings_dir, file))["arr_0"]
    all_embeddings.append(data.astype(np.float32))  # FAISS requires float32

  all_embeddings = np.vstack(all_embeddings)
  print(f"Loaded total embeddings: {all_embeddings.shape}")

  return all_embeddings


In [6]:
#Load and convert all npz embeddings to 32f from 16f
def load_embedding_32f():
  all_embeddings = []

  # npz_files = sorted([f for f in os.listdir(embeddings_dir) if f.endswith(".npz")])
  # print(npz_files)
  # for file in tqdm(npz_files, desc="Loading .npz chunks"):
  for chunk_id in tqdm(range(60), desc="Streaming "):
    file = f"kilt_embeddings_chunk_{chunk_id}.npz"
    data = np.load(os.path.join(embeddings_dir, file))["arr_0"]
    all_embeddings.append(data)  # FAISS requires float32 and these are 32

  all_embeddings = np.vstack(all_embeddings)
  print(f"Loaded total embeddings: {all_embeddings.shape}")

  return all_embeddings


In [7]:
all_embedings = load_embedding_32f()

Streaming : 100%|██████████| 60/60 [04:08<00:00,  4.15s/it]


Loaded total embeddings: (5099120, 384)


In [8]:
# Config
embedding_dim = 384         # Dim of your model
m = 64                      # Number of subquantizers (divides dim)
nbits = 8                   # Bits per subquantizer
nlist = 100                 # Number of IVF clusters (tune for large data)

In [9]:
import faiss

In [10]:
#Build IVF+PQ Index
quantizer = faiss.IndexFlatL2(embedding_dim)  # Used to initialize IVF
index = faiss.IndexIVFPQ(quantizer, embedding_dim, nlist, m, nbits)



In [11]:
print("Training index ...")
index.train(all_embedings)
print("Training done.")
index.add(all_embedings)
print(f"Added {index.ntotal} vectors to index.")

Training index ...
Training done.
Added 5099120 vectors to index.


In [12]:
#Save the index
faiss.write_index(index, f"{faiss_dir}/kilt_ivfpq_full_512_32B.index")
print(f"Index saved to {faiss_dir}")



Index saved to /content/drive/My Drive/kilt_rag_data/faiss_index


In [None]:
index

<faiss.swigfaiss_avx512.IndexIVFPQ; proxy of <Swig Object of type 'faiss::IndexIVFPQ *' at 0x7ac4848cff60> >