In [12]:
import random
import os
import faiss
import pickle
import torch
import numpy as np
import warnings
from dotenv import load_dotenv
from typing import List, Tuple, Optional

# Silences every warning for the rest of the session (including library deprecation
# notices); narrow this to specific categories/modules if warnings are needed for debugging.
warnings.filterwarnings('ignore')

In [13]:
# Load environment variables; values from .env override ones already set in the process
load_dotenv(override=True)


def _require_env(name: str) -> str:
    """Return the value of environment variable `name`.

    Raises:
        RuntimeError: if the variable is unset or empty. Failing here names the missing
            key, instead of a later opaque `int(None)` TypeError or a bogus 'None/...' path.
    """
    value = os.getenv(name)
    if value is None or not value.strip():
        raise RuntimeError(
            f"Environment variable {name!r} is required but not set; add it to your .env file"
        )
    return value


# General variables
seed = int(_require_env("SEED"))
batch_size = int(_require_env("BATCH_SIZE"))
save_every = int(_require_env("SAVE_EVERY"))

# Embedding variables
embeddings_dir = _require_env("EMBEDDINGS_DIR")

# Indexing variables
corpus_size = int(_require_env("CORPUS_SIZE"))
vector_size = int(_require_env("VECTOR_SIZE"))
faiss_dir = _require_env("FAISS_DIR")

In [14]:
# Utility functions
def set_seeds(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to every generator (default 42).

    Note:
        Setting PYTHONHASHSEED here only affects Python processes started afterwards;
        the running kernel's hash randomization is fixed at interpreter startup.
    """
    # Python standard library
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)

    # NumPy legacy global generator
    np.random.seed(seed)

    # PyTorch CPU generator, then the current and all CUDA devices
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_cuda(seed)

    # Disable cuDNN autotuning (which may pick different kernels per run) and require determinism
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def read_pickle(file_path: str):
    """Load and return the object pickled in `file_path`.

    Security: unpickling can execute arbitrary code, so only call this on files
    this project wrote itself, never on untrusted input.
    """
    with open(file_path, "rb") as in_file:
        return pickle.load(in_file)


def write_pickle(data, file_path: str):
    """Pickle `data` into `file_path`, creating or overwriting the file.

    Args:
        data: Any picklable object.
        file_path: Destination path; its parent directory must already exist.
    """
    with open(file_path, "wb") as out_file:
        pickle.dump(data, out_file)

In [15]:
# Seed every RNG up front so all later cells are reproducible
set_seeds(seed=seed)

In [4]:
# Embeddings are stored in multiple chunk files (one per `batch_size * save_every` documents);
# load_all_embeddings concatenates them into one array and caches it as `all_embeddings.npy`.
def _check_embedding_count(embeddings: np.ndarray, expected: int, source: str) -> None:
    """Raise ValueError if `embeddings` does not have exactly `expected` rows."""
    if embeddings.shape[0] != expected:
        raise ValueError(
            f'{source} holds {embeddings.shape[0]} embeddings but {expected} were expected'
        )


def load_all_embeddings(
    directory: Optional[str] = None,
    num_docs: Optional[int] = None,
    chunk_size: Optional[int] = None,
) -> np.ndarray:
    """Return the full embedding matrix, building and caching it from chunk files if needed.

    If `<directory>/all_embeddings.npy` exists it is opened with mmap_mode='c'
    (copy-on-write) and returned. Otherwise the chunk files `<directory>/<i>_embeddings.npy`
    are loaded in order, concatenated along axis 0, saved as `all_embeddings.npy`, and returned.

    Args:
        directory: Folder holding the embedding files. Defaults to `embeddings_dir`.
        num_docs: Number of rows expected in total. Defaults to `corpus_size`.
        chunk_size: Rows per chunk file. Defaults to `batch_size * save_every`.

    Returns:
        Array whose first axis has exactly `num_docs` rows.

    Raises:
        ValueError: If `num_docs` or `chunk_size` is not positive, or if the cached /
            concatenated data does not contain exactly `num_docs` rows (delete a stale
            `all_embeddings.npy` to rebuild it).
    """
    if directory is None:
        directory = embeddings_dir
    if num_docs is None:
        num_docs = corpus_size
    if chunk_size is None:
        chunk_size = batch_size * save_every
    if num_docs < 1:
        raise ValueError(f'num_docs must be positive, got {num_docs}')
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be positive, got {chunk_size}')

    all_embeddings_path = os.path.join(directory, 'all_embeddings.npy')

    # Reuse the concatenated file from a previous run if present
    if os.path.isfile(all_embeddings_path):
        embeddings = np.load(all_embeddings_path, mmap_mode='c')
        _check_embedding_count(embeddings, num_docs, all_embeddings_path)
        return embeddings

    # Chunk files appear to be named after the index of their last document: every
    # `chunk_size`-th index, plus the final index for a trailing partial chunk.
    last_indices = list(range(chunk_size - 1, num_docs, chunk_size))
    if not last_indices or last_indices[-1] != num_docs - 1:
        # Only add the final index when it does not already end a full chunk;
        # otherwise that file would be loaded twice and its rows duplicated.
        last_indices.append(num_docs - 1)

    chunks = [
        np.load(os.path.join(directory, f'{i}_embeddings.npy'), mmap_mode='c')
        for i in last_indices
    ]
    embeddings = np.concatenate(chunks, axis=0)

    # Validate before caching so a malformed result is never persisted
    _check_embedding_count(embeddings, num_docs, directory)
    np.save(all_embeddings_path, embeddings)

    return embeddings

In [None]:
# Load embeddings for the indexer and check they line up with the configured corpus / vector sizes
embeddings = load_all_embeddings()
assert embeddings.shape == (corpus_size, vector_size), (
    f'Expected embeddings of shape {(corpus_size, vector_size)}, got {embeddings.shape}'
)

In [10]:
class Indexer:
    """FAISS index plus a mapping from FAISS row ids to caller-supplied database ids.

    Vectors added with `index_data` get consecutive FAISS row ids in insertion order;
    `index_id_to_db_id[k]` holds the database id of row k so search hits can be
    reported as database ids.
    """

    # Default file names shared by serialize() and deserialize_from() so both stay in sync.
    # The metadata file holds a pickle despite its .faiss suffix; the name is kept for
    # compatibility with files already written.
    DEFAULT_INDEX_FILE_NAME = 'index.faiss'
    DEFAULT_META_FILE_NAME = 'index_meta.faiss'

    def __init__(self, vector_size: int):
        # Exact (brute-force) search scored by inner product
        self.index = faiss.IndexFlatIP(vector_size)
        self.index_id_to_db_id = []

    def index_data(self, ids: List[int], embeddings: np.ndarray):
        """
        Adds data to the index.

        Args:
            ids (List[int]): Database IDs, one per row of `embeddings`, in the same order.
            embeddings (np.ndarray): 2-D array with one embedding per row. It is converted
                to C-contiguous float32 (copying only if needed) because FAISS requires it.

        Raises:
            ValueError: If the number of ids differs from the number of embedding rows.
        """
        ids = list(ids)
        vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
        if len(ids) != vectors.shape[0]:
            raise ValueError(
                f'Got {len(ids)} ids for {vectors.shape[0]} embeddings; they must match one-to-one'
            )
        if not self.index.is_trained:
            self.index.train(vectors)
        self.index.add(vectors)
        # Extend the mapping only after FAISS accepted the vectors, so a failed add
        # cannot leave the mapping longer than the index.
        self._update_id_mapping(ids)

        print(f'Total data indexed {len(self.index_id_to_db_id)}')

    def search_knn(
        self,
        query_vectors: np.ndarray,
        top_docs: int,
        index_batch_size: int = 2048
    ) -> List[Tuple[List[str], np.ndarray]]:
        """
        Performs a k-nearest neighbor search for the given query vectors.

        Args:
            query_vectors (np.ndarray): 2-D array with one query vector per row.
            top_docs (int): The maximum number of documents to return for each query.
            index_batch_size (int): How many queries are sent to FAISS per search call.

        Returns:
            One (doc_ids, scores) tuple per query, in query order: database ids as strings
            and the corresponding scores reported by FAISS. FAISS pads results with id -1
            when fewer than `top_docs` vectors are available; those placeholders are dropped,
            so a query may return fewer than `top_docs` documents.
        """
        queries = np.ascontiguousarray(query_vectors, dtype=np.float32)
        result = []
        for start_idx in range(0, len(queries), index_batch_size):
            batch = queries[start_idx: start_idx + index_batch_size]
            scores, indexes = self.index.search(batch, top_docs)
            for row_scores, row_indexes in zip(scores, indexes):
                # -1 means "no result"; indexing the Python list with it would silently
                # return the last database id instead.
                found = row_indexes >= 0
                db_ids = [str(self.index_id_to_db_id[i]) for i in row_indexes[found]]
                result.append((db_ids, row_scores[found]))
        return result

    def _resolve_paths(
        self,
        dir_path: str,
        index_file_name: Optional[str],
        meta_file_name: Optional[str]
    ) -> Tuple[str, str]:
        """Join `dir_path` with the given (or default) index and metadata file names."""
        if index_file_name is None:
            index_file_name = self.DEFAULT_INDEX_FILE_NAME
        if meta_file_name is None:
            meta_file_name = self.DEFAULT_META_FILE_NAME
        return os.path.join(dir_path, index_file_name), os.path.join(dir_path, meta_file_name)

    def serialize(
        self,
        dir_path: str,
        index_file_name: Optional[str] = None,
        meta_file_name: Optional[str] = None
    ):
        """
        Serializes the index and its metadata to disk.

        Args:
            dir_path (str): The directory path to save the serialized index and metadata.
            index_file_name (Optional[str]): Optional custom name for the index file.
            meta_file_name (Optional[str]): Optional custom name for the metadata file.
        """
        index_file, meta_file = self._resolve_paths(dir_path, index_file_name, meta_file_name)
        print(f'Serializing index to {index_file}, meta data to {meta_file}')

        index_to_write = self.index
        if getattr(self, 'index_gpu', None) is self.index:
            # FAISS can only write CPU indexes; bring a GPU copy back first
            index_to_write = faiss.index_gpu_to_cpu(self.index)
        faiss.write_index(index_to_write, index_file)
        write_pickle(self.index_id_to_db_id, meta_file)

    def deserialize_from(
        self,
        dir_path: str,
        index_file_name: Optional[str] = None,
        meta_file_name: Optional[str] = None,
        gpu_id: Optional[int] = None
    ):
        """
        Loads the index and its metadata from disk.

        Args:
            dir_path (str): The directory path from where to load the index and metadata.
            index_file_name (Optional[str]): Optional custom name for the index file.
            meta_file_name (Optional[str]): Optional custom name for the metadata file.
            gpu_id (Optional[int]): If given, the loaded index is copied to this GPU and
                replaces the CPU index (also kept as `self.index_gpu`).
        """
        index_file, meta_file = self._resolve_paths(dir_path, index_file_name, meta_file_name)
        print(f'Loading index from {index_file}, meta data from {meta_file}')

        self.index = faiss.read_index(index_file)
        print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')

        self.index_id_to_db_id = read_pickle(meta_file)
        assert len(
            self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'

        # Move index to GPU if specified
        if gpu_id is not None:
            res = faiss.StandardGpuResources()
            self.index_gpu = faiss.index_cpu_to_gpu(res, gpu_id, self.index)
            self.index = self.index_gpu
            print(f'Moved index to GPU {gpu_id}')

    def _update_id_mapping(self, db_ids: List[int]):
        """Append database ids for newly added FAISS rows, preserving insertion order."""
        self.index_id_to_db_id.extend(db_ids)

    def get_index_name(self):
        return "index"

In [11]:
def indexing_embeddings(embeddings: np.ndarray) -> None:
    """Build an `Indexer` over `embeddings` and serialize it into `faiss_dir`.

    Row i of `embeddings` is registered under database id i, so the matrix must have
    shape (corpus_size, vector_size).

    Args:
        embeddings: Embedding matrix, one document per row.

    Raises:
        ValueError: If `embeddings` does not have shape (corpus_size, vector_size).
            Checked before anything is written to disk.
    """
    expected_shape = (corpus_size, vector_size)
    if embeddings.shape != expected_shape:
        raise ValueError(f'Expected embeddings of shape {expected_shape}, got {embeddings.shape}')

    os.makedirs(faiss_dir, exist_ok=True)

    index = Indexer(vector_size)
    index.index_data(list(range(corpus_size)), embeddings)

    index.serialize(
        dir_path=faiss_dir,
        index_file_name='index.faiss',
        meta_file_name='index_meta.faiss'
    )

In [None]:
# Build the FAISS index over all embeddings and write it to disk
indexing_embeddings(embeddings=embeddings)