In [27]:
import os
import pickle
import random
import unicodedata
import warnings
from typing import Callable, Dict, List, Optional, Tuple

import faiss
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from tqdm import tqdm
from transformers import PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer

# Turn off HF tokenizers' internal parallelism up front, before any tokenizer is created below.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Use the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
warnings.filterwarnings('ignore')

In [22]:
# Load environment variables; override=True lets values from .env replace ones already set in the shell.
load_dotenv(override=True)


def _require_env(name: str) -> str:
    """Return environment variable `name`, raising a descriptive error if it is unset or blank."""
    value = os.getenv(name)
    if value is None or not value.strip():
        raise RuntimeError(f"Environment variable {name} is required but not set (check your .env file).")
    return value


def _env_int(name: str) -> int:
    """Read a required integer environment variable."""
    return int(_require_env(name))


def _env_flag(name: str) -> bool:
    """Read an optional boolean flag: "true"/"1"/"yes" (any case) mean True; unset or anything else means False."""
    return os.getenv(name, "").strip().lower() in {"true", "1", "yes"}


# General variables
seed = _env_int("SEED")
batch_size = _env_int("BATCH_SIZE")
save_every = _env_int("SAVE_EVERY")

# Embeddings variables
corpus_path = os.getenv("CORPUS_PATH")
max_length_encoder = _env_int("MAX_LENGTH_ENCODER")
normalize_embeddings = _env_flag("NORMALIZE_EMBEDDINGS")
lower_case = _env_flag("LOWER_CASE")
# NOTE: this is a boolean flag, not a text-normalization module.
normalize_text = _env_flag("NORMALIZE_TEXT")
embeddings_dir = os.getenv("EMBEDDINGS_DIR")

# Indexing variables
vector_size = _env_int("VECTOR_SIZE")
faiss_dir = os.getenv("FAISS_DIR")

# Retrieval variables
top_k = _env_int("TOP_K")
use_gpu = _env_flag("USE_GPU")
# Comma-separated ids such as "0,1". Unset/blank -> [] so CPU-only runs need not define GPU_IDS;
# empty items (e.g. a trailing comma) are ignored instead of crashing int("").
gpu_ids = [int(gpu_id) for gpu_id in os.getenv("GPU_IDS", "").split(",") if gpu_id.strip()]
use_test_set = _env_flag("USE_TEST_SET")
index_batch_size = _env_int("INDEX_BATCH_SIZE")
search_dir = os.getenv("SEARCH_DIR")

In [16]:
# Dataset split to process: "test" when USE_TEST_SET is enabled, otherwise "train".
split = "test" if use_test_set else "train"

# Fail fast: GPU mode loads one index shard per GPU id, so at least one id is required.
if use_gpu and not gpu_ids:
    raise ValueError('gpu_ids must be set when use_gpu is used.')

# Input JSON file per split (relative paths resolve against the notebook's working directory).
split_paths = {
    "train": {
        "data_path": '../data/train.json',
    },
    "dev": {
        "data_path": '../data/dev.json',
    },
    "test": {
        "data_path": '../data/test.json',
    }
}

In [17]:
# Utility functions
def set_seeds(seed=42):
    """
    Seed Python's `random`, NumPy and torch (CPU and all CUDA devices) and force deterministic cuDNN.

    PYTHONHASHSEED is exported too; note that changing it at runtime does not re-seed the
    hash randomization of the already-running interpreter (it only affects child processes).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN autotuning speed for reproducible kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def read_pickle(file_path: str):
    """
    Load and return the object pickled in `file_path`.

    Security: unpickling can execute arbitrary code, so only use this on files you produced yourself.
    """
    with open(file_path, mode="rb") as handle:
        return pickle.load(handle)


def write_pickle(data, file_path: str):
    """Serialize `data` to `file_path` with pickle, overwriting any existing file."""
    with open(file_path, mode="wb") as handle:
        pickle.dump(data, handle)

In [7]:
# Seed every RNG (Python, NumPy, torch) before any stochastic work in later cells.
set_seeds(seed=seed)

In [11]:
def load_queries() -> List[str]:
    """
    Read the JSON file of the active `split` and return its query strings in file order.

    The text is taken from the 'query' column when present, otherwise from 'question'.
    """
    frame = pd.read_json(split_paths[split]['data_path'])
    text_column = 'query' if 'query' in frame.columns else 'question'
    return frame[text_column].tolist()

In [None]:
# Query strings of the active split, in file order.
queries: List[str] = load_queries()

In [18]:
class Indexer(object):
    """
    A FAISS inner-product index paired with a list mapping FAISS row ids to external document ids.
    """
    def __init__(self, vector_size: int):
        """
        Args:
            vector_size (int): Dimensionality of the vectors stored in the index.
        """
        # Exact inner-product index; `deserialize_from` replaces it with the index read from disk.
        self.index = faiss.IndexFlatIP(vector_size)
        # index_id_to_db_id[i] is the external id of the i-th vector stored in `self.index`.
        self.index_id_to_db_id = []

    def search_knn(
        self, 
        query_vectors: np.array, 
        top_docs: int, 
        index_batch_size: int = 2048
    ) -> List[Tuple[List[str], List[float]]]:
        """
        Performs a k-nearest neighbor search for the given query vectors.

        Args:
            query_vectors (np.array): Query vectors, one per row. They are converted to
                C-contiguous float32 before being handed to FAISS.
            top_docs (int): The number of top documents to return for each query.
            index_batch_size (int): How many queries are sent to FAISS per `search` call.

        Returns:
            One (document ids, scores) tuple per query, in query order. Ids are strings; scores
            are the corresponding numpy score row from FAISS. Results FAISS marks as missing
            (label -1) are dropped, so a tuple may hold fewer than `top_docs` entries.

        Raises:
            ValueError: If `index_batch_size` is not positive.
        """
        if index_batch_size <= 0:
            raise ValueError(f"index_batch_size must be positive, got {index_batch_size}")

        result = []
        for start_idx in range(0, len(query_vectors), index_batch_size):
            # FAISS's Python API requires C-contiguous float32 input.
            q = np.ascontiguousarray(
                query_vectors[start_idx:start_idx + index_batch_size], dtype=np.float32
            )
            scores, indexes = self.index.search(q, top_docs)
            for row_scores, row_ids in zip(scores, indexes):
                # FAISS pads missing results with label -1; without this mask, -1 would
                # silently index the *last* entry of `index_id_to_db_id`.
                keep = row_ids >= 0
                db_ids = [str(self.index_id_to_db_id[i]) for i in row_ids[keep]]
                result.append((db_ids, row_scores[keep]))
        return result


    def deserialize_from(
        self, 
        dir_path: str, 
        index_file_name: Optional[str] = None, 
        meta_file_name: Optional[str] = None,
        gpu_id: Optional[int] = None
    ):
        """
        Loads the index and its metadata from disk.

        Args:
            dir_path (str): The directory path from where to load the index and metadata.
            index_file_name (Optional[str]): Custom name for the index file (default 'index.faiss').
            meta_file_name (Optional[str]): Custom name for the pickled id-mapping file
                (default 'index_meta.faiss').
            gpu_id (Optional[int]): If given, the loaded index is copied to this GPU and used from there.

        Security: the metadata file is read with pickle; only load files you created.
        """
        if index_file_name is None:
            index_file_name = 'index.faiss'
        if meta_file_name is None:
            meta_file_name = 'index_meta.faiss'

        index_file = os.path.join(dir_path, index_file_name)
        meta_file = os.path.join(dir_path, meta_file_name)
        print(f'Loading index from {index_file}, meta data from {meta_file}')

        self.index = faiss.read_index(index_file)
        print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')

        self.index_id_to_db_id = read_pickle(meta_file)
        assert len(
            self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'

        # Move index to GPU if specified
        if gpu_id is not None:
            # Keep the GPU resources referenced for as long as the GPU index that uses them.
            self.gpu_resources = faiss.StandardGpuResources()
            self.index_gpu = faiss.index_cpu_to_gpu(self.gpu_resources, gpu_id, self.index)
            self.index = self.index_gpu
            print(f'Moved index to GPU {gpu_id}')

In [20]:
def initialize_index() -> List[Indexer]:
    """
    Load the FAISS index(es) from `faiss_dir`.

    CPU mode loads the single 'index.faiss' / 'index_meta.faiss' pair. GPU mode loads one shard
    per entry of `gpu_ids` ('index1.faiss', 'index2.faiss', ...) and places shard n on the n-th GPU id.
    """
    if not use_gpu:
        cpu_index = Indexer(vector_size)
        cpu_index.deserialize_from(faiss_dir)
        return [cpu_index]

    shards = []
    for shard_no, gpu_id in enumerate(gpu_ids, start=1):
        shard = Indexer(vector_size)
        shard.deserialize_from(
            faiss_dir,
            f'index{shard_no}.faiss',
            f'index{shard_no}_meta.faiss',
            gpu_id=gpu_id,
        )
        shards.append(shard)
    return shards

In [None]:
# One Indexer on CPU, or one per GPU shard.
indexes: List[Indexer] = initialize_index()

In [24]:
class Encoder(PreTrainedModel):
    """
    A wrapper class for encoding text using pre-trained transformer models with specified pooling strategy.

    Supported values of ``config.pooling``:
        - "average": mean of the last hidden states over non-padding tokens.
        - "cls": last hidden state of the first token.
    """
    def __init__(self, config: AutoConfig, pooling: str = "average"):
        """
        Args:
            config: Model config; ``config.name_or_path`` is the checkpoint loaded into ``self.model``.
            pooling: Pooling strategy, used only when ``config`` has no ``pooling`` attribute yet.
        """
        super().__init__(config)
        self.config = config
        # A `pooling` value already stored on the config takes precedence over the argument.
        if not hasattr(self.config, "pooling"):
            self.config.pooling = pooling

        self.model = AutoModel.from_pretrained(
            config.name_or_path, config=self.config
        )


    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the wrapped transformer and return its raw output (indexable by "last_hidden_state")."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

    def encode(
        self, 
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        normalize: bool = False
    ) -> torch.Tensor:
        """
        Encode a batch of token ids into one pooled embedding per sequence.

        Args:
            input_ids: Token ids, one row per sequence.
            attention_mask: 1 for real tokens, 0 for padding. If None, every token is treated as real.
            token_type_ids: Optional segment ids forwarded to the model.
            normalize: If True, L2-normalize each embedding.

        Returns:
            A tensor with one pooled embedding per row of `input_ids`.

        Raises:
            ValueError: If ``config.pooling`` is not "average" or "cls".
        """
        if attention_mask is None:
            # The signature allows None; treat every position as a real token.
            attention_mask = torch.ones_like(input_ids)

        model_output = self.forward(
            input_ids, 
            attention_mask,
            token_type_ids,
        )
        # Zero out padding positions so they do not contribute to pooling.
        last_hidden = model_output["last_hidden_state"]
        last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.)

        pooling = self.config.pooling
        if pooling == "average":
            # clamp(min=1) avoids 0/0 = NaN for a row that is entirely padding.
            token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
            emb = last_hidden.sum(dim=1) / token_counts
        elif pooling == "cls":
            emb = last_hidden[:, 0]
        else:
            raise ValueError(f"Unsupported pooling strategy {pooling!r}; expected 'average' or 'cls'")

        if normalize:
            emb = F.normalize(emb, dim=-1)

        return emb

In [28]:
class Retriever:
    """
    A class for retrieving document embeddings using a specified encoder, using a bi-encoder approach.

    Only query encoding is implemented here. `doc_encoder` and `norm_doc_emb` are stored
    but not used by any method of this class.
    """
    def __init__(
        self,
        device: torch.device,
        tokenizer: AutoTokenizer,
        query_encoder: Encoder,
        doc_encoder: Optional[Encoder] = None,
        max_length: int = 512,
        add_special_tokens: bool = True,
        norm_query_emb: bool = False,
        norm_doc_emb: bool = False,
        lower_case: bool = False,
        do_normalize_text: bool = False,
        text_normalizer: Optional[Callable[[str], str]] = None,
    ):
        """
        Args:
            device: Device the encoders are moved to and inputs are sent to.
            tokenizer: Tokenizer used to turn queries into model inputs.
            query_encoder: Encoder for queries (also used for documents if `doc_encoder` is None).
            doc_encoder: Optional separate document encoder.
            max_length: Truncation length passed to the tokenizer.
            add_special_tokens: Forwarded to the tokenizer.
            norm_query_emb: L2-normalize query embeddings.
            norm_doc_emb: Stored flag for document embeddings (unused here).
            lower_case: Lower-case queries before tokenization.
            do_normalize_text: Apply text normalization to queries before tokenization.
            text_normalizer: Callable used when `do_normalize_text` is True; defaults to Unicode
                NFKC normalization. NOTE(review): pass contriever's `normalize_text.normalize`
                here if exact parity with the upstream preprocessing is required.
        """
        self.device = device
        self.query_encoder = query_encoder.to(device)
        # Tied bi-encoder unless a separate document encoder is supplied.
        self.doc_encoder = self.query_encoder if doc_encoder is None else doc_encoder.to(device)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.add_special_tokens = add_special_tokens
        self.norm_query_emb = norm_query_emb
        self.norm_doc_emb = norm_doc_emb
        self.lower_case = lower_case
        self.do_normalize_text = do_normalize_text
        self.text_normalizer = text_normalizer

    def _normalize_text(self, text: str) -> str:
        """Apply the configured text normalizer (Unicode NFKC by default) to a single query."""
        if self.text_normalizer is not None:
            return self.text_normalizer(text)
        return unicodedata.normalize("NFKC", text)

    def encode_queries(self, queries: List[str], batch_size: int) -> torch.Tensor:
        """
        Encode queries in batches and return the embeddings stacked on the CPU, one row per query.

        Args:
            queries: Query strings.
            batch_size: Number of queries tokenized and encoded per forward pass.

        Returns:
            A CPU torch.Tensor of embeddings (callers convert it with `.numpy()`).

        Raises:
            ValueError: If `batch_size` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        # Do not use the module-level `normalize_text` here: in this notebook it is the
        # NORMALIZE_TEXT boolean flag, not a text-normalization module.
        if self.do_normalize_text:
            queries = [self._normalize_text(q) for q in queries]
        if self.lower_case:
            queries = [q.lower() for q in queries]

        batch_embeddings = []
        with torch.no_grad():
            for start_idx in range(0, len(queries), batch_size):
                q_inputs = self.tokenizer(
                    queries[start_idx:start_idx + batch_size],
                    max_length=self.max_length,
                    padding=True,
                    truncation=True,
                    add_special_tokens=self.add_special_tokens,
                    return_tensors="pt",
                ).to(self.device)

                emb = self.query_encoder.encode(**q_inputs, normalize=self.norm_query_emb)
                batch_embeddings.append(emb.cpu())

        if not batch_embeddings:
            # torch.cat([]) raises; return an empty matrix with the encoder's width instead.
            hidden_size = getattr(getattr(self.query_encoder, "config", None), "hidden_size", 0)
            return torch.empty((0, hidden_size))
        return torch.cat(batch_embeddings, dim=0)

In [29]:
def initialize_retriever() -> Retriever:
    """
    Build the facebook/contriever query encoder (in eval mode) and wrap it in a `Retriever`.

    Max length, embedding normalization, lower-casing and text normalization come from the
    notebook's configuration globals.
    """
    checkpoint = "facebook/contriever"
    encoder = Encoder(AutoConfig.from_pretrained(checkpoint)).eval()
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return Retriever(
        device=device,
        tokenizer=tokenizer,
        query_encoder=encoder,
        max_length=max_length_encoder,
        norm_query_emb=normalize_embeddings,
        lower_case=lower_case,
        do_normalize_text=normalize_text,
    )

In [None]:
def process_queries(retriever: Retriever, queries: List[str], batch_size: int) -> np.ndarray:
    """Encode `queries` with `retriever` and return the embeddings as a NumPy array, one row per query."""
    embeddings = retriever.encode_queries(queries, batch_size=batch_size)
    return embeddings.numpy()

In [None]:
# Build the query encoder, then embed every query of the active split.
retriever = initialize_retriever()
query_embeddings = process_queries(retriever=retriever, queries=queries, batch_size=batch_size)

In [30]:
def merge_ip_search_results(
    indexer1: Indexer, 
    indexer2: Indexer, 
    query_vectors: np.array, 
    top_docs: int, 
    index_batch_size: int = 2048
) -> List[Tuple[List[str], List[float]]]:
    """
    Merges the k-nearest neighbor search results from two different indices for a given set of query vectors.

    Args:
        indexer1 (Indexer): The first indexer object capable of performing knn searches.
        indexer2 (Indexer): The second indexer object capable of performing knn searches.
        query_vectors (np.array): A numpy array of query vectors for which to perform the searches.
        top_docs (int): The number of top documents to retrieve from the combined results of the two indexers.
        index_batch_size (int): The batch size to use for indexing operations.

    Returns:
        One (db ids, scores) tuple per query, holding at most `top_docs` hits from both indexes
        sorted by descending inner-product score. On equal scores, indexer1's hits come first.
    """
    # Perform searches on both indices
    results1 = indexer1.search_knn(query_vectors, top_docs, index_batch_size)
    results2 = indexer2.search_knn(query_vectors, top_docs, index_batch_size)

    merged_results = []
    for (ids1, scores1), (ids2, scores2) in zip(results1, results2):
        # Concatenate as Python lists: search_knn returns numpy score arrays, for which `+`
        # would add element-wise instead of appending the second index's hits.
        combined_db_ids = list(ids1) + list(ids2)
        combined_scores = list(scores1) + list(scores2)

        # Inner product: higher is better. sorted() is stable, so ties keep indexer1 first.
        combined = sorted(
            zip(combined_db_ids, combined_scores), key=lambda pair: pair[1], reverse=True
        )[:top_docs]

        # Building the two lists explicitly also handles a query with no hits (zip(*[]) would fail).
        db_ids = [db_id for db_id, _ in combined]
        scores = [score for _, score in combined]
        merged_results.append((db_ids, scores))

    return merged_results

In [31]:
def search_documents(
    indexes: List[Indexer], 
    query_embeddings: np.ndarray, 
) -> List[Tuple[List[str], List[float]]]:
    """
    Search every index and keep the global top-`top_k` hits per query.

    With a single index (CPU mode) its results are returned unchanged. With several shards
    (GPU mode, one per entry of `gpu_ids`) the per-shard hits are merged by descending
    inner-product score. This works for any number of shards, not just two.

    Raises:
        ValueError: If `indexes` is empty.
    """
    if not indexes:
        raise ValueError("search_documents needs at least one index")

    per_index_results = [
        index.search_knn(query_embeddings, top_docs=top_k, index_batch_size=index_batch_size)
        for index in indexes
    ]
    if len(per_index_results) == 1:
        return per_index_results[0]

    merged_results = []
    for hits_per_shard in zip(*per_index_results):
        # Pair ids with scores per shard, then concatenate as Python lists. Scores may be numpy
        # arrays, where `+` would add element-wise instead of concatenating.
        pairs = [
            (db_id, score)
            for db_ids, scores in hits_per_shard
            for db_id, score in zip(db_ids, scores)
        ]
        # Stable sort: on equal scores, earlier shards come first.
        pairs.sort(key=lambda pair: pair[1], reverse=True)
        top = pairs[:top_k]
        merged_results.append(([db_id for db_id, _ in top], [score for _, score in top]))
    return merged_results

In [None]:
def save_search_results(
    search_results: List[Tuple[List[str], List[float]]], 
):
    """
    Pickle `search_results` to '<search_dir>/<split>_search_results_at<top_k>.pkl'.

    `search_dir` is created first if it does not exist; an existing file is overwritten.
    """
    output_path = os.path.join(search_dir, f'{split}_search_results_at{top_k}.pkl')
    os.makedirs(search_dir, exist_ok=True)
    write_pickle(search_results, output_path)

In [None]:
# Retrieve the top-k documents for every query, then persist them for later stages.
search_results = search_documents(indexes=indexes, query_embeddings=query_embeddings)
save_search_results(search_results=search_results)