In [1]:
import random
import os
import faiss
import pickle
import torch
import numpy as np
import warnings
from dotenv import load_dotenv
from typing import List, Optional

# Silence every warning issued through the `warnings` module for the rest of the session.
# NOTE(review): this blanket filter also hides deprecation/runtime warnings coming from
# the libraries imported above — consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

In [2]:
# Load environment variables from the .env file. With override=True, values in
# the file replace variables that are already set in the process environment.
load_dotenv(override=True)

# Every setting below is required by a later cell. Indexing os.environ (rather
# than calling os.getenv) makes a missing variable fail right here with a
# KeyError that names it, instead of an opaque `int(None)` TypeError or a
# path that silently starts with "None/".

# General variables
seed = int(os.environ["SEED"])              # seed handed to set_seeds()
batch_size = int(os.environ["BATCH_SIZE"])  # multiplied by save_every to get the shard stride
save_every = int(os.environ["SAVE_EVERY"])  # multiplied by batch_size to get the shard stride

# Embedding variables
embeddings_dir = os.environ["EMBEDDINGS_DIR"]  # directory holding the *_embeddings.npy files

# Indexing variables
corpus_size = int(os.environ["CORPUS_SIZE"])  # number of database IDs to index
vector_size = int(os.environ["VECTOR_SIZE"])  # dimensionality passed to the FAISS index
faiss_dir = os.environ["FAISS_DIR"]           # output directory for the serialized index

In [3]:
# Utility functions
def set_seeds(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU and all
    CUDA devices), exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic mode with benchmarking turned off.

    Args:
        seed: Value used for every generator. Defaults to 42.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Apply the same seed to each library's seeding entry point in turn
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def read_pickle(file_path: str):
    """Load and return the object pickled in ``file_path``.

    Args:
        file_path (str): Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # Unpickling can execute arbitrary code: only use on files you trust.
    with open(file_path, mode="rb") as source:
        return pickle.load(source)


def write_pickle(data, file_path: str):
    """Pickle ``data`` into ``file_path``, overwriting any existing file.

    Args:
        data: Any picklable object.
        file_path (str): Destination path of the pickle file.
    """
    with open(file_path, mode="wb") as sink:
        pickle.dump(data, sink)

In [4]:
# Run seeder before proceeding: seeds every RNG with the SEED value read from the environment
set_seeds(seed)

In [5]:
# Embeddings are stored in multiple files due to batch size, this function saves all embeddings into a single file
def load_all_embeddings() -> np.ndarray:
    """Return all embeddings as one matrix, merging the shard files if needed.

    If ``{embeddings_dir}/all_embeddings.npy`` already exists it is opened as
    a copy-on-write memory map and returned directly. Otherwise every shard
    file ``{embeddings_dir}/{i}_embeddings.npy`` is loaded, the shards are
    concatenated along axis 0, the result is saved to ``all_embeddings.npy``
    and the in-memory array is returned.

    Reads the module-level settings ``embeddings_dir``, ``batch_size``,
    ``save_every`` and ``corpus_size``.

    Returns:
        np.ndarray: The concatenated embeddings (an ``np.memmap`` when served
        from the cached file).
    """
    all_embeddings_path = f'{embeddings_dir}/all_embeddings.npy'

    # Check if the file with all embeddings already exists and in case load it
    if os.path.isfile(all_embeddings_path):
        return np.load(all_embeddings_path, mmap_mode='c')

    # Shard files are numbered num_embed - 1, 2 * num_embed - 1, ... plus a
    # final file numbered corpus_size - 1 (presumably the index of the last
    # item each shard contains — confirm against the embedding script).
    num_embed = batch_size * save_every
    shard_indices = list(range(num_embed - 1, corpus_size, num_embed))

    # When corpus_size is an exact multiple of num_embed the regular sequence
    # already ends at corpus_size - 1; adding it again would load the final
    # shard twice and duplicate its rows in the merged matrix.
    last_idx = corpus_size - 1
    if not shard_indices or shard_indices[-1] != last_idx:
        shard_indices.append(last_idx)

    all_embeddings = [
        np.load(f'{embeddings_dir}/{i}_embeddings.npy', mmap_mode='c')
        for i in shard_indices
    ]

    embeddings = np.concatenate(all_embeddings, axis=0)
    np.save(all_embeddings_path, embeddings)

    return embeddings

In [6]:
# Load embeddings for indexer (taken from the merged all_embeddings.npy when it
# already exists, otherwise assembled from the per-batch files and cached)
embeddings = load_all_embeddings()

In [7]:
class Indexer:
    """A FAISS inner-product index together with a position -> database ID map."""

    def __init__(self, vector_size: int):
        """
        Creates an empty index.

        Args:
            vector_size (int): Dimensionality of the vectors that will be indexed.
        """
        # Flat (exhaustive-search) inner-product index
        self.index = faiss.IndexFlatIP(vector_size)
        # index_id_to_db_id[i] is the database ID of the i-th vector added to the index
        self.index_id_to_db_id = []

    def index_data(self, ids: List[int], embeddings: np.ndarray):
        """
        Adds data to the index.

        Args:
            ids (List[int]): A list of database IDs corresponding to the embeddings.
            embeddings (np.ndarray): A numpy array of embeddings to be indexed,
                one row per ID.

        Raises:
            ValueError: If the number of IDs differs from the number of
                embedding rows. Nothing is added in that case.
        """
        # FAISS works on C-contiguous float32 matrices. This is a no-op (no
        # copy) when the input already has that dtype and layout.
        embeddings = np.ascontiguousarray(embeddings, dtype='float32')

        # Index positions are mapped to database IDs purely by insertion
        # order, so a length mismatch would silently misalign the two.
        if len(ids) != embeddings.shape[0]:
            raise ValueError(
                f'Got {len(ids)} ids for {embeddings.shape[0]} embeddings; '
                'the two must have the same length'
            )

        if not self.index.is_trained:
            self.index.train(embeddings)
        self.index.add(embeddings)
        # Record the IDs only once the vectors are in the index, so a failed
        # add cannot leave the mapping ahead of the index.
        self._update_id_mapping(ids)

        print(f'Total data indexed {len(self.index_id_to_db_id)}')

    def serialize(
        self,
        dir_path: str,
        index_file_name: Optional[str] = None,
        meta_file_name: Optional[str] = None
    ):
        """
        Serializes the index and its metadata to disk.

        The index is written with ``faiss.write_index`` and the ID mapping is
        pickled with ``write_pickle``. ``dir_path`` must already exist.

        Args:
            dir_path (str): The directory path to save the serialized index and metadata.
            index_file_name (Optional[str]): Optional custom name for the index file.
                Defaults to 'index.faiss'.
            meta_file_name (Optional[str]): Optional custom name for the metadata file.
                Defaults to 'index_meta.faiss' (a pickle file despite the extension).
        """
        if index_file_name is None:
            index_file_name = 'index.faiss'
        if meta_file_name is None:
            meta_file_name = 'index_meta.faiss'

        index_file = os.path.join(dir_path, index_file_name)
        meta_file = os.path.join(dir_path, meta_file_name)
        print(f'Serializing index to {index_file}, meta data to {meta_file}')

        faiss.write_index(self.index, index_file)
        write_pickle(self.index_id_to_db_id, meta_file)

    def _update_id_mapping(self, db_ids: List[int]):
        """Appends ``db_ids`` to the position -> database ID mapping."""
        self.index_id_to_db_id.extend(db_ids)

In [8]:
def indexing_embeddings(embeddings: np.array) -> None:
    """Index ``embeddings`` and write the index plus its ID mapping to ``faiss_dir``.

    Builds an ``Indexer`` of dimension ``vector_size``, adds the embeddings
    under the database IDs ``0 .. corpus_size - 1`` and serializes the result
    as ``index.faiss`` / ``index_meta.faiss`` inside ``faiss_dir`` (created if
    it does not exist yet).

    Args:
        embeddings (np.array): Embedding matrix handed to ``Indexer.index_data``.
    """
    # The output directory has to exist before the indexer serializes into it
    os.makedirs(faiss_dir, exist_ok=True)

    db_ids = list(range(corpus_size))
    indexer = Indexer(vector_size)
    indexer.index_data(db_ids, embeddings)
    indexer.serialize(
        dir_path=faiss_dir,
        index_file_name='index.faiss',
        meta_file_name='index_meta.faiss',
    )

In [9]:
# Build the index over the loaded embeddings and serialize it into faiss_dir
indexing_embeddings(embeddings)

Total data indexed 563424
Serializing index to ../data/embeddings/indexes/index.faiss, meta data to ../data/embeddings/indexes/index_meta.faiss
