# Notebook Initialization

In [11]:
!pip install datasets
!pip install sentence_transformers
!pip install faiss-cpu
!pip install faiss-gpu
!pip install scann
!pip install rapidfuzz
!pip install python-Levenshtein
!pip install rank-bm25
!pip install spacy
!python -m spacy download en_core_web_sm

[31mERROR: Could not find a version that satisfies the requirement faiss-gpu (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for faiss-gpu[0m[31m
Collecting en-core-web-sm==3.7.1
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m106.0 MB/s[0m eta [36m0:00:00[0m
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [12]:
from google.colab import drive
from typing import Any, Callable, Iterable
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from rapidfuzz.process import cdist
from rank_bm25 import BM25Okapi
import os
import time
import re
import json
import copy
import pickle
import enum
import torch
import faiss
import scann
import spacy
import Levenshtein
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [13]:
drive.mount('/content/drive')

DATASET_ROOT = '/content/drive/MyDrive/ADSP Project/datasets/'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
RETRIEVAL_CAPACITY = 100

if not os.path.exists(DATASET_ROOT):
    raise ValueError('Invalid data root')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Classes

## Config

In [15]:
class Config:

    class DATASET_NAMES(enum.Enum):
        MS_MARCO = 'ms-marco'
        HOTPOT_QA = 'hotpot-qa'

    class TRANSFORMER_MODEL_NAMES(enum.Enum):
        ALL_MPNET_BASE_V2 = 'all-mpnet-base-v2'
        MULTI_QA_MPNET_BASE_DOT_V1 = 'multi-qa-mpnet-base-dot-v1'
        ALL_DISTILROBERTA_V1 = 'all-distilroberta-v1'

    class VECTOR_DB_NAMES(enum.Enum):
        FAISS = 'faiss'
        SCANN = 'scann'

    class SIMILARITY_METRIC_NAMES(enum.Enum):
        L2 = 'l2'
        IP = 'ip'
        CS = 'cs'

    class SCORER_NAMES(enum.Enum):
        SIMPLE = 'simple'
        LEVENSHTEIN = 'levenshtein'
        JARO_WINKLER = 'jaro_winkler'

    class AGGREGATOR_NAMES(enum.Enum):
        MIN_MIN_0 = 'min-min-0'
        MIN_AVG_0 = 'min-avg-0'
        MIN_MAX_0 = 'min-max-0'
        AVG_MIN_0 = 'avg-min-0'
        AVG_AVG_0 = 'avg-avg-0'
        AVG_MAX_0 = 'avg-max-0'
        MAX_MIN_0 = 'max-min-0'
        MAX_AVG_0 = 'max-avg-0'
        MAX_MAX_0 = 'max-max-0'
        MIN_MIN_1 = 'min-min-1'
        MIN_AVG_1 = 'min-avg-1'
        MIN_MAX_1 = 'min-max-1'
        AVG_MIN_1 = 'avg-min-1'
        AVG_AVG_1 = 'avg-avg-1'
        AVG_MAX_1 = 'avg-max-1'
        MAX_MIN_1 = 'max-min-1'
        MAX_AVG_1 = 'max-avg-1'
        MAX_MAX_1 = 'max-max-1'

    class TOKENIZER_NAMES(enum.Enum):
        SIMPLE = 'simple'
        LEMMA = 'lemmatization'

## Dataset

In [16]:
class Dataset:

    def __init__(self, file_name:str) -> None:
        self._file_name:str = file_name
        self._stat_dict = {
            'passages': dict[str, int](),
            'queries': dict[str, int](),
            'augmentations': dict[str, int](),
            'relations': dict[str, int](),
            'learning': dict[str, int]()
        }
        self.dataset_name:Config.DATASET_NAMES = None
        self.passage_list = list[str]()
        self.query_list = list[str]()
        self.passage_augmentation_list = list[dict[str, dict[str, int]]]()
        self.query_augmentation_list = list[dict[str, dict[str, int]]]()
        self.augmentation_dict = dict[str, set[int]]()
        self.relation_list = list[set[int]]()
        self.train_set = set[int]()
        self.validation_set = set[int]()
        self.test_set = set[int]()
        potential_dataset_path = os.path.join(DATASET_ROOT, f'{file_name}.pickle')
        if os.path.exists(potential_dataset_path):
            with open(potential_dataset_path, 'rb') as file_handle:
                public_dataset = pickle.load(file_handle)
                for attribute in public_dataset:
                    setattr(self, attribute, public_dataset[attribute])
            if self.dataset_name not in {item.value for item in Config.DATASET_NAMES}:
                raise ValueError('Invalid dataset name')
            self.dataset_name = Config.DATASET_NAMES(self.dataset_name)
        else:
            raise ValueError('Invalid file name')
        self._update_stat()

    def __str__(self) -> str:
        output_list = [f'names -> file: {self._file_name}, dataset: {self.dataset_name}']
        for stat in self._stat_dict:
            if len(self._stat_dict[stat]) == 0:
                continue
            output_list.append(f'{stat} -> ' + ', '.join(f'{attribute}: {self._stat_dict[stat][attribute]}' for attribute in self._stat_dict[stat]))
        return '\n'.join(output_list)

    def _update_stat(self) -> None:
        def __count_quantity(key:str, suffix:str, target_list:list[Any]) -> None:
            self._stat_dict[key][f'total_{suffix}'] = len(target_list)
        def __compute_stat(key:str, suffix:str, target_list:list[Iterable]) -> None:
            if len(target_list) > 0:
                self._stat_dict[key][f'minimum_{suffix}'] = min(len(iterable) for iterable in target_list)
                self._stat_dict[key][f'average_{suffix}'] = round(sum(len(iterable) for iterable in target_list) / len(target_list))
                self._stat_dict[key][f'maximum_{suffix}'] = max(len(iterable) for iterable in target_list)
        __count_quantity('passages', '', self.passage_list)
        __compute_stat('passages', 'length', self.passage_list)
        __count_quantity('queries', '', self.query_list)
        __compute_stat('queries', 'length', self.query_list)
        for augmentation_name in self.augmentation_dict:
            __count_quantity('augmentations', f'queries_augmented_with_{augmentation_name}', self.augmentation_dict[augmentation_name])
        __compute_stat('relations', 'related_passages', self.relation_list)
        __count_quantity('learning', 'queries_in_train_set', self.train_set)
        __count_quantity('learning', 'queries_in_validation_set', self.validation_set)
        __count_quantity('learning', 'queries_in_test_set', self.test_set)

    def _get_recall(self, query_index:int, query_retrieved_passage_indices:np.ndarray) -> list[float]:
        total_related_passages = len(self.relation_list[query_index])
        recall_list = list[float]()
        for k in range(1, query_retrieved_passage_indices.size + 1):
            recall_list.append(len(self.relation_list[query_index].intersection(query_retrieved_passage_indices[:k])) / total_related_passages)
        return recall_list

    def _get_optimistic_mrr(self, query_index:int, query_retrieved_passage_indices:np.ndarray) -> float:
        optimistic_mrr = 0.0
        for rank, retrieved_passage_index in enumerate(query_retrieved_passage_indices, start=1):
            if retrieved_passage_index in self.relation_list[query_index]:
                optimistic_mrr = 1.0 / rank
                break
        return optimistic_mrr

    def _get_pessimistic_mrr(self, query_index:int, query_retrieved_passage_indices:np.ndarray) -> float:
        total_related_passages = len(self.relation_list[query_index])
        pessimistic_mrr = 0.0
        for rank in range(total_related_passages, query_retrieved_passage_indices.size + 1):
            if len(self.relation_list[query_index].intersection(query_retrieved_passage_indices[:rank])) == total_related_passages:
                pessimistic_mrr = total_related_passages / rank
                break
        return pessimistic_mrr

    def get_metrics(self, query_index_list:list[int], retrieved_passage_indices:np.ndarray) -> tuple[dict[int, float], dict[int, float], float, float]:
        recall_dict = dict[int, list[float]]()
        recall_star_dict = dict[int, list[float]]()
        optimistic_mrr_list = list[float]()
        pessimistic_mrr_list = list[float]()
        for i in range(len(query_index_list)):
            total_related_passages = len(self.relation_list[query_index_list[i]])
            if total_related_passages == 0:
                continue
            recall_list = self._get_recall(query_index_list[i], retrieved_passage_indices[i, :])
            optimistic_mrr = self._get_optimistic_mrr(query_index_list[i], retrieved_passage_indices[i, :])
            pessimistic_mrr = self._get_pessimistic_mrr(query_index_list[i], retrieved_passage_indices[i, :])
            for k, recall in enumerate(recall_list, start=1):
                if k not in recall_dict:
                    recall_dict[k] = list[float]()
                recall_dict[k].append(recall)
                if k == total_related_passages:
                    if total_related_passages not in recall_star_dict:
                        recall_star_dict[total_related_passages] = list[float]()
                    recall_star_dict[total_related_passages].append(recall)
            optimistic_mrr_list.append(optimistic_mrr)
            pessimistic_mrr_list.append(pessimistic_mrr)
        avg_recall_dict = {k: sum(recall_list) / len(recall_list) for k, recall_list in dict(sorted(recall_dict.items())).items()}
        avg_recall_start_dict = {total_related_passages: sum(recall_star_list) / len(recall_star_list) for total_related_passages, recall_star_list in dict(sorted(recall_star_dict.items())).items()}
        avg_optimistic_mrr = sum(optimistic_mrr_list) / len(optimistic_mrr_list)
        avg_pessimistic_mrr = sum(pessimistic_mrr_list) / len(pessimistic_mrr_list)
        return avg_recall_dict, avg_recall_start_dict, avg_optimistic_mrr, avg_pessimistic_mrr

    def print_metrics(self, query_index_list:list[int], retrieved_passage_indices:np.ndarray) -> None:
        avg_recall_dict, avg_recall_start_dict, avg_optimistic_mrr, avg_pessimistic_mrr = self.get_metrics(query_index_list, retrieved_passage_indices)
        print('Dataset Name ->', self.dataset_name)
        print('Recall ->', ' | '.join(f'{k}: {100 * avg_recall:.2f}%' for k, avg_recall in avg_recall_dict.items()))
        print('Cluster Recall ->', ' | '.join(f'{total_related_passages}: {100 * avg_recall:.2f}%' for total_related_passages, avg_recall in avg_recall_start_dict.items()))
        print('MRR ->', f'optimistic: {100 * avg_optimistic_mrr:.2f}% | pessimistic: {100 * avg_pessimistic_mrr:.2f}%')

## Transformer

In [17]:
class Transformer:

    def __init__(self, model_name:Config.TRANSFORMER_MODEL_NAMES) -> None:
        self.model_name = model_name
        self._transformer = SentenceTransformer(model_name_or_path=model_name.value, device=DEVICE)
        self._transformer.encode(['warm_up'])

    def embed(self, text_list:list[str]) -> np.ndarray:
        print('Embedding ...', end='')
        embeddings = self._transformer.encode(text_list, convert_to_numpy=True)
        print(' done')
        return embeddings

## Semantic Searcher

In [18]:
class SemanticSearcher:

    def __init__(self, vectordb_name:Config.VECTOR_DB_NAMES, similarity_metric_name:Config.SIMILARITY_METRIC_NAMES) -> None:
        self.vectordb_name = vectordb_name
        self.similarity_metric_name = similarity_metric_name
        if self.vectordb_name == Config.VECTOR_DB_NAMES.FAISS:
            if self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.L2:
                self._engine = faiss.IndexFlatL2(768)
            elif self.similarity_metric_name in [Config.SIMILARITY_METRIC_NAMES.IP, Config.SIMILARITY_METRIC_NAMES.CS]:
                self._engine = faiss.IndexFlatIP(768)
        elif self.vectordb_name == Config.VECTOR_DB_NAMES.SCANN:
            self._engine:Any = None

    def index(self, passage_embeddings:np.ndarray) -> None:
        if self.vectordb_name == Config.VECTOR_DB_NAMES.FAISS:
            if self.similarity_metric_name in [Config.SIMILARITY_METRIC_NAMES.L2, Config.SIMILARITY_METRIC_NAMES.IP]:
                self._engine.add(passage_embeddings)
            elif self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.CS:
                normalized_passage_embeddings = passage_embeddings / np.linalg.norm(passage_embeddings, axis=1, keepdims=True)
                self._engine.add(normalized_passage_embeddings)
        elif self.vectordb_name == Config.VECTOR_DB_NAMES.SCANN:
            if self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.L2:
                self._engine = scann.scann_ops_pybind.builder(passage_embeddings, RETRIEVAL_CAPACITY, 'squared_l2').score_brute_force().build()
            elif self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.IP:
                self._engine = scann.scann_ops_pybind.builder(passage_embeddings, RETRIEVAL_CAPACITY, 'dot_product').score_brute_force().build()
            elif self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.CS:
                normalized_passage_embeddings = passage_embeddings / np.linalg.norm(passage_embeddings, axis=1, keepdims=True)
                self._engine = scann.scann_ops_pybind.builder(normalized_passage_embeddings, RETRIEVAL_CAPACITY, 'dot_product').score_brute_force().build()

    def search(self, query_embeddings:np.ndarray) -> np.ndarray:
        if self.vectordb_name == Config.VECTOR_DB_NAMES.FAISS:
            if self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.L2:
                retrieved_passage_indices = self._engine.search(query_embeddings, RETRIEVAL_CAPACITY)[1]
            elif self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.IP:
                retrieved_passage_indices = self._engine.search(query_embeddings, RETRIEVAL_CAPACITY)[1]
            elif self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.CS:
                normalized_query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
                retrieved_passage_indices = self._engine.search(normalized_query_embeddings, RETRIEVAL_CAPACITY)[1]
        elif self.vectordb_name == Config.VECTOR_DB_NAMES.SCANN:
            retrieved_passage_index_matrix = list[list[int]]()
            if self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.L2:
                for i in range(query_embeddings.shape[0]):
                    retrieved_passage_index_list = self._engine.search(query_embeddings[i, :])[0]
                    retrieved_passage_index_matrix.append(retrieved_passage_index_list)
            elif self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.IP:
                for i in range(query_embeddings.shape[0]):
                    retrieved_passage_index_list = self._engine.search(query_embeddings[i, :])[0]
                    retrieved_passage_index_matrix.append(retrieved_passage_index_list)
            elif self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.CS:
                normalized_query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
                for i in range(normalized_query_embeddings.shape[0]):
                    retrieved_passage_index_list = self._engine.search(normalized_query_embeddings[i, :])[0]
                    retrieved_passage_index_matrix.append(retrieved_passage_index_list)
            retrieved_passage_indices = np.array(retrieved_passage_index_matrix)
        return retrieved_passage_indices

## Syntactic Searcher

In [19]:
class SyntacticSearcher:

    def __init__(self, scorer_name:Config.SCORER_NAMES, aggregator_name:Config.AGGREGATOR_NAMES) -> None:
        self.scorer_name = scorer_name
        self.aggregator_name = aggregator_name
        if scorer_name == Config.SCORER_NAMES.SIMPLE:
            self._scorer:Callable[[str, str, float], float] = lambda query_entity, passage_entity, score_cutoff : float(query_entity == passage_entity)
        elif scorer_name == Config.SCORER_NAMES.LEVENSHTEIN:
            self._scorer:Callable[[str, str, float], float] = Levenshtein.distance
        elif scorer_name == Config.SCORER_NAMES.JARO_WINKLER:
            self._scorer:Callable[[str, str, float], float] = Levenshtein.jaro_winkler
        if aggregator_name == Config.AGGREGATOR_NAMES.MIN_MIN_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).min(dim=0)[0].min()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MIN_AVG_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).min(dim=0)[0].mean()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MIN_MAX_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).min(dim=0)[0].max()
        elif aggregator_name == Config.AGGREGATOR_NAMES.AVG_MIN_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).mean(dim=0).min()
        elif aggregator_name == Config.AGGREGATOR_NAMES.AVG_AVG_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).mean(dim=0).mean()
        elif aggregator_name == Config.AGGREGATOR_NAMES.AVG_MAX_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).mean(dim=0).max()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MAX_MIN_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).max(dim=0)[0].min()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MAX_AVG_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).max(dim=0)[0].mean()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MAX_MAX_0:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).max(dim=0)[0].max()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MIN_MIN_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).min(dim=1)[0].min()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MIN_AVG_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).min(dim=1)[0].mean()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MIN_MAX_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).min(dim=1)[0].max()
        elif aggregator_name == Config.AGGREGATOR_NAMES.AVG_MIN_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).mean(dim=1).min()
        elif aggregator_name == Config.AGGREGATOR_NAMES.AVG_AVG_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).mean(dim=1).mean()
        elif aggregator_name == Config.AGGREGATOR_NAMES.AVG_MAX_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).mean(dim=1).max()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MAX_MIN_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).max(dim=1)[0].min()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MAX_AVG_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).max(dim=1)[0].mean()
        elif aggregator_name == Config.AGGREGATOR_NAMES.MAX_MAX_1:
            self._aggregator:Callable[[np.ndarray], torch.Tensor] = lambda entity_similarities: torch.tensor(entity_similarities, device=DEVICE).max(dim=1)[0].max()
        self._passage_augmentation_list:list[dict[str, dict[str, int]]] = None

    def index(self, passage_augmentation_list:list[dict[str, dict[str, int]]]) -> None:
        self._passage_augmentation_list = passage_augmentation_list

    def search(self, query_augmentation_list:list[dict[str, dict[str, int]]]) -> tuple[np.ndarray, np.ndarray]:
        distances = torch.ones((len(query_augmentation_list), len(self._passage_augmentation_list)), device=DEVICE)
        with tqdm(total=len(query_augmentation_list) * len(self._passage_augmentation_list), desc='Computing Syntactic Similarities') as pbar:
            for i in range(len(query_augmentation_list)):
                for j in range(len(self._passage_augmentation_list)):
                    common_key_set = query_augmentation_list[i].keys() & self._passage_augmentation_list[j].keys()
                    if len(common_key_set) > 0:
                        similarities = torch.zeros(len(common_key_set), device=DEVICE)
                        for k, key in enumerate(common_key_set):
                            query_entity_array = np.array(list(query_augmentation_list[i][key].keys()))
                            query_frequency_array = np.array(list(query_augmentation_list[i][key].values()))
                            passage_entity_array = np.array(list(self._passage_augmentation_list[j][key].keys()))
                            passage_frequency_array = np.array(list(self._passage_augmentation_list[j][key].values()))
                            entity_similarities = cdist(query_entity_array, passage_entity_array, scorer=self._scorer)
                            entity_similarities = entity_similarities * query_frequency_array.reshape(-1, 1) / query_frequency_array.max()
                            entity_similarities = entity_similarities * passage_frequency_array.reshape(1, -1) / passage_frequency_array.max()
                            similarities[k] = self._aggregator(entity_similarities)
                        distances[i, j] = 1.0 - similarities.mean() * len(common_key_set) / len(query_augmentation_list[i])
                    pbar.update(1)
        print()
        retrieved_passage_indices = torch.topk(distances, k=min(RETRIEVAL_CAPACITY, distances.shape[1]), dim=1, largest=False).indices
        return retrieved_passage_indices.cpu().numpy()

## BM25 Searcher

In [20]:
class BM25Searcher:

    def __init__(self, tokenizer_name:Config.TOKENIZER_NAMES) -> None:
        spacy.prefer_gpu()
        self.tokenizer_name = tokenizer_name
        self._nlp = spacy.load('en_core_web_sm')
        if tokenizer_name == Config.TOKENIZER_NAMES.SIMPLE:
            self._tokenizer:Callable[[str], list[str]] = lambda text: text.split(' ')
        elif tokenizer_name == Config.TOKENIZER_NAMES.LEMMA:
            self._tokenizer:Callable[[str], list[str]] = lambda text: [token.lemma_ for token in self._nlp(text)]
        self._engine:BM25Okapi = None
        self._passage_list:list[str] = None

    def index(self, passage_list:list[str]) -> None:
        tokenized_passage_list = list[list[str]]()
        with tqdm(total=len(passage_list), desc='Tokenizing Passages') as pbar:
            for passage in passage_list:
                tokenized_passage_list.append(self._tokenizer(passage))
                pbar.update(1)
        print()
        self._engine = BM25Okapi(tokenized_passage_list)
        self._passage_list = passage_list

    def search(self, query_list:list[str]) -> tuple[np.ndarray, np.ndarray]:
        tokenized_query_list = list[list[str]]()
        with tqdm(total=len(query_list), desc='Tokenizing Queries') as pbar:
            for query in query_list:
                tokenized_query_list.append(self._tokenizer(query))
                pbar.update(1)
        print()
        distances = torch.ones((len(query_list), len(self._passage_list)), device=DEVICE)
        with tqdm(total=len(tokenized_query_list), desc='Computing BM25 Scores') as pbar:
            for i in range(len(query_list)):
                distances[i, :] = 1.0 - torch.tensor(self._engine.get_scores(tokenized_query_list[i]), device=DEVICE)
                pbar.update(1)
        print()
        retrieved_passage_indices = torch.topk(distances, k=min(RETRIEVAL_CAPACITY, distances.shape[1]), dim=1, largest=False).indices
        return retrieved_passage_indices.cpu().numpy()

# Operation

## Baselines

In [21]:
transformer = Transformer(Config.TRANSFORMER_MODEL_NAMES.ALL_MPNET_BASE_V2)
ms_marco_dataset = Dataset('ms-marco-spacy-llama8b')
hotpot_qa_dataset = Dataset('hotpot-qa-spacy-llama8b')
print()
print(ms_marco_dataset)
print()
print(hotpot_qa_dataset)


names -> file: ms-marco-spacy-llama8b, dataset: DATASET_NAMES.MS_MARCO
passages -> total_: 814, minimum_length: 88, average_length: 420, maximum_length: 922
queries -> total_: 100, minimum_length: 11, average_length: 34, maximum_length: 76
augmentations -> total_queries_augmented_with_spacy_ner: 100, total_queries_augmented_with_llama3-8b-8192_keyword_and_topic_extraction: 100
relations -> minimum_related_passages: 4, average_related_passages: 8, maximum_related_passages: 10
learning -> total_queries_in_train_set: 80, total_queries_in_validation_set: 10, total_queries_in_test_set: 10

names -> file: hotpot-qa-spacy-llama8b, dataset: DATASET_NAMES.HOTPOT_QA
passages -> total_: 1000, minimum_length: 80, average_length: 598, maximum_length: 8307
queries -> total_: 100, minimum_length: 40, average_length: 111, maximum_length: 418
augmentations -> total_queries_augmented_with_spacy_ner: 100, total_queries_augmented_with_llama3-8b-8192_keyword_and_topic_extraction: 38
relations -> minimum_r

In [22]:
print('B1: Syntactic Search I')
ms_marco_syntactic_searcher = SyntacticSearcher(Config.SCORER_NAMES.JARO_WINKLER, Config.AGGREGATOR_NAMES.MAX_AVG_1)
ms_marco_syntactic_searcher.index(ms_marco_dataset.passage_augmentation_list)
ms_marco_b1_retrieved_passage_indices = ms_marco_syntactic_searcher.search(ms_marco_dataset.query_augmentation_list)
ms_marco_dataset.print_metrics(list(range(len(ms_marco_dataset.query_list))), ms_marco_b1_retrieved_passage_indices)

print()
print('B2: Syntactic Search II')
ms_marco_bm25_searcher = BM25Searcher(Config.TOKENIZER_NAMES.LEMMA)
ms_marco_bm25_searcher.index(ms_marco_dataset.passage_list)
ms_marco_b2_retrieved_passage_indices = ms_marco_bm25_searcher.search(ms_marco_dataset.query_list)
ms_marco_dataset.print_metrics(list(range(len(ms_marco_dataset.query_list))), ms_marco_b2_retrieved_passage_indices)

print()
print('B3: Semantic Search')
ms_marco_semantic_searcher = SemanticSearcher(Config.VECTOR_DB_NAMES.FAISS, Config.SIMILARITY_METRIC_NAMES.CS)
ms_marco_b3_passage_embeddings = transformer.embed(ms_marco_dataset.passage_list)
ms_marco_b3_query_embeddings = transformer.embed(ms_marco_dataset.query_list)
ms_marco_semantic_searcher.index(ms_marco_b3_passage_embeddings)
ms_marco_b3_retrieved_passage_indices = ms_marco_semantic_searcher.search(ms_marco_b3_query_embeddings)
ms_marco_dataset.print_metrics(list(range(len(ms_marco_dataset.query_list))), ms_marco_b3_retrieved_passage_indices)

B1: Syntactic Search I


Computing Syntactic Similarities: 100%|██████████| 81400/81400 [00:48<00:00, 1690.94it/s]



Dataset Name -> DATASET_NAMES.MS_MARCO
Recall -> 1: 11.02% | 2: 21.94% | 3: 32.51% | 4: 41.12% | 5: 48.87% | 6: 55.12% | 7: 60.08% | 8: 63.95% | 9: 66.94% | 10: 69.10% | 11: 70.37% | 12: 72.44% | 13: 73.48% | 14: 74.58% | 15: 75.59% | 16: 76.58% | 17: 76.93% | 18: 77.60% | 19: 78.23% | 20: 78.58% | 21: 79.05% | 22: 79.45% | 23: 79.78% | 24: 79.89% | 25: 80.00% | 26: 80.12% | 27: 80.45% | 28: 80.55% | 29: 80.82% | 30: 81.06% | 31: 81.57% | 32: 81.91% | 33: 82.11% | 34: 82.47% | 35: 82.63% | 36: 82.63% | 37: 82.89% | 38: 83.00% | 39: 83.14% | 40: 83.82% | 41: 84.05% | 42: 84.37% | 43: 84.48% | 44: 84.48% | 45: 84.58% | 46: 84.58% | 47: 84.69% | 48: 84.69% | 49: 84.95% | 50: 85.37% | 51: 85.49% | 52: 85.80% | 53: 86.16% | 54: 86.33% | 55: 86.33% | 56: 86.33% | 57: 86.64% | 58: 86.64% | 59: 86.64% | 60: 86.64% | 61: 86.64% | 62: 86.78% | 63: 86.91% | 64: 86.91% | 65: 87.17% | 66: 87.34% | 67: 87.61% | 68: 87.71% | 69: 87.83% | 70: 87.83% | 71: 87.94% | 72: 88.30% | 73: 88.30% | 74: 88.30%

Tokenizing Passages: 100%|██████████| 814/814 [00:30<00:00, 26.80it/s]





Tokenizing Queries: 100%|██████████| 100/100 [00:01<00:00, 65.94it/s]





Computing BM25 Scores: 100%|██████████| 100/100 [00:00<00:00, 726.44it/s]



Dataset Name -> DATASET_NAMES.MS_MARCO
Recall -> 1: 12.32% | 2: 23.39% | 3: 33.86% | 4: 43.92% | 5: 52.71% | 6: 61.04% | 7: 67.63% | 8: 71.86% | 9: 74.54% | 10: 76.11% | 11: 77.54% | 12: 78.43% | 13: 79.10% | 14: 79.86% | 15: 80.36% | 16: 80.86% | 17: 81.10% | 18: 81.44% | 19: 81.87% | 20: 81.97% | 21: 81.97% | 22: 82.07% | 23: 82.07% | 24: 82.41% | 25: 82.63% | 26: 82.74% | 27: 82.96% | 28: 83.16% | 29: 83.26% | 30: 83.39% | 31: 83.51% | 32: 83.51% | 33: 83.61% | 34: 83.61% | 35: 83.71% | 36: 83.71% | 37: 83.81% | 38: 83.81% | 39: 84.14% | 40: 84.24% | 41: 84.24% | 42: 84.38% | 43: 84.38% | 44: 84.38% | 45: 84.38% | 46: 84.48% | 47: 84.48% | 48: 84.48% | 49: 84.58% | 50: 84.58% | 51: 84.68% | 52: 84.95% | 53: 84.95% | 54: 84.95% | 55: 84.95% | 56: 84.95% | 57: 84.95% | 58: 85.21% | 59: 85.21% | 60: 85.21% | 61: 85.21% | 62: 85.34% | 63: 85.45% | 64: 85.45% | 65: 85.45% | 66: 85.45% | 67: 85.45% | 68: 85.45% | 69: 85.56% | 70: 85.67% | 71: 85.67% | 72: 85.67% | 73: 85.67% | 74: 85.67%

In [24]:
print('B1: Syntactic Search I')
hotpot_qa_syntactic_searcher = SyntacticSearcher(Config.SCORER_NAMES.JARO_WINKLER, Config.AGGREGATOR_NAMES.MAX_AVG_1)
hotpot_qa_syntactic_searcher.index(hotpot_qa_dataset.passage_augmentation_list)
hotpot_qa_b1_retrieved_passage_indices = hotpot_qa_syntactic_searcher.search(hotpot_qa_dataset.query_augmentation_list)
hotpot_qa_dataset.print_metrics(list(range(len(hotpot_qa_dataset.query_list))), hotpot_qa_b1_retrieved_passage_indices)

print()
print('B2: Syntactic Search II')
hotpot_qa_bm25_searcher = BM25Searcher(Config.TOKENIZER_NAMES.LEMMA)
hotpot_qa_bm25_searcher.index(hotpot_qa_dataset.passage_list)
hotpot_qa_b2_retrieved_passage_indices = hotpot_qa_bm25_searcher.search(hotpot_qa_dataset.query_list)
hotpot_qa_dataset.print_metrics(list(range(len(hotpot_qa_dataset.query_list))), hotpot_qa_b2_retrieved_passage_indices)

print()
print('B3: Semantic Search')
hotpot_qa_semantic_searcher = SemanticSearcher(Config.VECTOR_DB_NAMES.FAISS, Config.SIMILARITY_METRIC_NAMES.CS)
hotpot_qa_b3_passage_embeddings = transformer.embed(hotpot_qa_dataset.passage_list)
hotpot_qa_b3_query_embeddings = transformer.embed(hotpot_qa_dataset.query_list)
hotpot_qa_semantic_searcher.index(hotpot_qa_b3_passage_embeddings)
hotpot_qa_b3_retrieved_passage_indices = hotpot_qa_semantic_searcher.search(hotpot_qa_b3_query_embeddings)
hotpot_qa_dataset.print_metrics(list(range(len(hotpot_qa_dataset.query_list))), hotpot_qa_b3_retrieved_passage_indices)

B1: Syntactic Search I


Computing Syntactic Similarities: 100%|██████████| 100000/100000 [00:46<00:00, 2145.74it/s]



Dataset Name -> DATASET_NAMES.HOTPOT_QA
Recall -> 1: 8.00% | 2: 13.40% | 3: 18.10% | 4: 21.70% | 5: 25.10% | 6: 28.10% | 7: 30.40% | 8: 32.30% | 9: 34.00% | 10: 35.70% | 11: 36.50% | 12: 37.10% | 13: 38.00% | 14: 38.60% | 15: 39.50% | 16: 40.40% | 17: 41.20% | 18: 41.70% | 19: 42.40% | 20: 43.10% | 21: 43.70% | 22: 43.90% | 23: 44.50% | 24: 44.60% | 25: 45.20% | 26: 45.40% | 27: 45.60% | 28: 45.70% | 29: 46.20% | 30: 46.50% | 31: 46.80% | 32: 46.90% | 33: 47.20% | 34: 47.30% | 35: 47.60% | 36: 47.90% | 37: 48.00% | 38: 48.50% | 39: 48.80% | 40: 49.20% | 41: 49.50% | 42: 49.90% | 43: 50.10% | 44: 50.50% | 45: 50.60% | 46: 50.70% | 47: 51.00% | 48: 51.40% | 49: 51.40% | 50: 51.50% | 51: 51.50% | 52: 51.70% | 53: 52.00% | 54: 52.00% | 55: 52.60% | 56: 52.90% | 57: 53.10% | 58: 53.60% | 59: 54.00% | 60: 54.50% | 61: 54.60% | 62: 54.60% | 63: 54.70% | 64: 54.90% | 65: 55.20% | 66: 55.50% | 67: 55.60% | 68: 55.90% | 69: 55.90% | 70: 56.00% | 71: 56.00% | 72: 56.10% | 73: 56.30% | 74: 56.40%

Tokenizing Passages: 100%|██████████| 1000/1000 [00:28<00:00, 35.07it/s]





Tokenizing Queries: 100%|██████████| 100/100 [00:01<00:00, 51.67it/s]





Computing BM25 Scores: 100%|██████████| 100/100 [00:00<00:00, 191.83it/s]



Dataset Name -> DATASET_NAMES.HOTPOT_QA
Recall -> 1: 10.00% | 2: 19.80% | 3: 29.20% | 4: 38.30% | 5: 47.60% | 6: 55.80% | 7: 63.70% | 8: 71.30% | 9: 78.30% | 10: 83.60% | 11: 86.60% | 12: 88.20% | 13: 89.50% | 14: 90.00% | 15: 91.00% | 16: 91.60% | 17: 92.20% | 18: 92.50% | 19: 93.10% | 20: 93.80% | 21: 94.20% | 22: 94.30% | 23: 94.40% | 24: 94.60% | 25: 94.70% | 26: 94.90% | 27: 95.10% | 28: 95.30% | 29: 95.50% | 30: 95.60% | 31: 95.70% | 32: 95.80% | 33: 95.80% | 34: 95.90% | 35: 95.90% | 36: 95.90% | 37: 96.10% | 38: 96.20% | 39: 96.50% | 40: 96.60% | 41: 96.60% | 42: 96.60% | 43: 96.60% | 44: 96.60% | 45: 96.60% | 46: 96.60% | 47: 96.70% | 48: 96.70% | 49: 96.80% | 50: 97.00% | 51: 97.00% | 52: 97.00% | 53: 97.00% | 54: 97.10% | 55: 97.20% | 56: 97.20% | 57: 97.20% | 58: 97.30% | 59: 97.30% | 60: 97.40% | 61: 97.40% | 62: 97.40% | 63: 97.40% | 64: 97.50% | 65: 97.50% | 66: 97.50% | 67: 97.50% | 68: 97.50% | 69: 97.50% | 70: 97.50% | 71: 97.50% | 72: 97.50% | 73: 97.50% | 74: 97.50

## Improvements

In [40]:
ms_marco_transformer = Transformer(Config.TRANSFORMER_MODEL_NAMES.ALL_MPNET_BASE_V2)
ms_marco_semantic_searcher = SemanticSearcher(Config.VECTOR_DB_NAMES.FAISS, Config.SIMILARITY_METRIC_NAMES.CS)

ms_marco_dataset = Dataset('ms-marco-no-augmentation')
print()
print(ms_marco_dataset)

ms_marco_b3_passage_embeddings = ms_marco_transformer.embed(ms_marco_dataset.passage_list)
ms_marco_b3_query_embeddings = ms_marco_transformer.embed(ms_marco_dataset.query_list)
ms_marco_semantic_searcher.index(ms_marco_b3_passage_embeddings)
ms_marco_b3_retrieved_passage_indices = ms_marco_semantic_searcher.search(ms_marco_b3_query_embeddings)

print()
print('B3: Semantic Search (MS-Marco, Train)')
ms_marco_train_query_index_list = list(range(len(ms_marco_dataset.train_set)))
ms_marco_dataset.print_metrics(ms_marco_train_query_index_list, ms_marco_b3_retrieved_passage_indices[ms_marco_train_query_index_list, :])

print()
print('B3: Semantic Search (MS-Marco, Test)')
ms_marco_test_query_index_list = list(range(len(ms_marco_dataset.test_set)))
ms_marco_dataset.print_metrics(ms_marco_test_query_index_list, ms_marco_b3_retrieved_passage_indices[ms_marco_test_query_index_list, :])


names -> file: ms-marco-no-augmentation, dataset: DATASET_NAMES.MS_MARCO
passages -> total_: 8209, minimum_length: 59, average_length: 418, maximum_length: 922
queries -> total_: 1000, minimum_length: 11, average_length: 34, maximum_length: 109
relations -> minimum_related_passages: 2, average_related_passages: 8, maximum_related_passages: 10
learning -> total_queries_in_train_set: 800, total_queries_in_validation_set: 100, total_queries_in_test_set: 100
Embedding ... done
Embedding ... done

B3: Semantic Search (MS-Marco, Train)
Dataset Name -> DATASET_NAMES.MS_MARCO
Recall -> 1: 12.65% | 2: 25.26% | 3: 37.55% | 4: 49.41% | 5: 60.48% | 6: 70.27% | 7: 78.65% | 8: 85.84% | 9: 90.50% | 10: 92.62% | 11: 93.48% | 12: 94.19% | 13: 94.66% | 14: 95.06% | 15: 95.58% | 16: 95.94% | 17: 96.23% | 18: 96.53% | 19: 96.65% | 20: 96.77% | 21: 96.98% | 22: 97.08% | 23: 97.15% | 24: 97.30% | 25: 97.33% | 26: 97.46% | 27: 97.55% | 28: 97.62% | 29: 97.73% | 30: 97.76% | 31: 97.77% | 32: 97.85% | 33: 97.

In [45]:
hotpot_qa_transformer = Transformer(Config.TRANSFORMER_MODEL_NAMES.ALL_MPNET_BASE_V2)
hotpot_qa_semantic_searcher = SemanticSearcher(Config.VECTOR_DB_NAMES.FAISS, Config.SIMILARITY_METRIC_NAMES.CS)

hotpot_qa_dataset = Dataset('hotpot-qa-no-augmentation')
print()
print(hotpot_qa_dataset)

hotpot_qa_b3_passage_embeddings = hotpot_qa_transformer.embed(hotpot_qa_dataset.passage_list)
hotpot_qa_b3_query_embeddings = hotpot_qa_transformer.embed(hotpot_qa_dataset.query_list)
hotpot_qa_semantic_searcher.index(hotpot_qa_b3_passage_embeddings)
hotpot_qa_b3_retrieved_passage_indices = hotpot_qa_semantic_searcher.search(hotpot_qa_b3_query_embeddings)

print()
print('B3: Semantic Search (Hotpot-QA, Train)')
hotpot_qa_train_query_index_list = list(range(len(hotpot_qa_dataset.train_set)))
hotpot_qa_dataset.print_metrics(hotpot_qa_train_query_index_list, hotpot_qa_b3_retrieved_passage_indices[hotpot_qa_train_query_index_list, :])

print()
print('B3: Semantic Search (Hotpot-QA, Test)')
hotpot_qa_test_query_index_list = list(range(len(hotpot_qa_dataset.test_set)))
hotpot_qa_dataset.print_metrics(hotpot_qa_test_query_index_list, hotpot_qa_b3_retrieved_passage_indices[hotpot_qa_test_query_index_list, :])


names -> file: hotpot-qa-no-augmentation, dataset: DATASET_NAMES.HOTPOT_QA
passages -> total_: 9913, minimum_length: 63, average_length: 567, maximum_length: 8307
queries -> total_: 1000, minimum_length: 32, average_length: 104, maximum_length: 542
relations -> minimum_related_passages: 1, average_related_passages: 10, maximum_related_passages: 10
learning -> total_queries_in_train_set: 800, total_queries_in_validation_set: 100, total_queries_in_test_set: 100
Embedding ... done
Embedding ... done

B3: Semantic Search (Hotpot-QA, Train)
Dataset Name -> DATASET_NAMES.HOTPOT_QA
Recall -> 1: 9.27% | 2: 15.97% | 3: 21.22% | 4: 25.67% | 5: 29.54% | 6: 32.91% | 7: 35.62% | 8: 38.14% | 9: 40.51% | 10: 42.35% | 11: 43.89% | 12: 45.44% | 13: 46.86% | 14: 47.91% | 15: 48.92% | 16: 49.89% | 17: 50.68% | 18: 51.46% | 19: 52.13% | 20: 52.90% | 21: 53.54% | 22: 54.22% | 23: 54.86% | 24: 55.48% | 25: 55.98% | 26: 56.41% | 27: 56.94% | 28: 57.37% | 29: 57.72% | 30: 58.12% | 31: 58.52% | 32: 58.86% | 3

## Class and Functions

In [None]:
class Mapper(torch.nn.Module):

    def __init__(self) -> None:
        super(Mapper, self).__init__()
        self.linear = torch.nn.Linear(768, 768)
        self.reset()

    def reset(self) -> None:
        with torch.no_grad():
            self.linear.weight.data = torch.eye(768).to(device=DEVICE)
            self.linear.bias.zero_()

    def forward(self, batch_query_embeddings:torch.Tensor) -> torch.Tensor:
        batch_mapped_query_embeddings = self.linear(batch_query_embeddings)
        return batch_mapped_query_embeddings

def get_positive_indices(dataset:Dataset, query_index:int, query_baseline_retrieved_passage_indices:np.array, total:int, mode:str) -> list[int]:
    if mode == 'random':
        positive_index_list = random.sample(list(dataset.relation_list[query_index]), total)
    else:
        if mode == 'worst-worst' or mode == 'worst-best':
            query_baseline_retrieved_passage_indices = np.flipud(query_baseline_retrieved_passage_indices)
        positive_index_list = list[int]()
        for passage_index in query_baseline_retrieved_passage_indices:
            if len(positive_index_list) == total:
                break
            if passage_index in dataset.relation_list[query_index]:
                positive_index_list.append(passage_index)
    return positive_index_list

def get_negative_indices(dataset:Dataset, query_index:int, query_baseline_retrieved_passage_indices:np.array, total:int, mode:str) -> list[int]:
    if mode == 'random':
        negative_index_list = list[int]()
        while len(negative_index_list) < total:
            negative_index = random.choice(list(dataset.train_set))
            while negative_index in dataset.relation_list[query_index]:
                negative_index = random.choice(list(dataset.train_set))
            negative_index_list.append(negative_index)
    else:
        if mode == 'worst-worst' or mode == 'best-worst':
            query_baseline_retrieved_passage_indices = np.flipud(query_baseline_retrieved_passage_indices)
        negative_index_list = list[int]()
        for passage_index in query_baseline_retrieved_passage_indices:
            if len(negative_index_list) == total:
                break
            if passage_index not in dataset.relation_list[query_index]:
                negative_index_list.append(passage_index)
    return negative_index_list

def get_targets(dataset:Dataset, passage_embeddings:np.array, baseline_retrieved_passage_indices:np.array, batch_query_index_list:list[int], preferred_total:int, positive_tendency:float, mode:str) -> tuple[torch.Tensor, torch.Tensor]:
    batch_total_positives_list = list[int]()
    batch_total_negatives_list = list[int]()
    for query_index in batch_query_index_list:
        total_positives = preferred_total * positive_tendency
        total_negatives = preferred_total - total_positives
        total_positives_error = max(1.0, total_positives / len(dataset.relation_list[query_index]))
        total_positives = round(total_positives / total_positives_error)
        total_negatives = round(total_negatives / total_positives_error)
        batch_total_positives_list.append(total_positives)
        batch_total_negatives_list.append(total_negatives)
    batch_positive_embeddings = torch.full((len(batch_query_index_list), max(batch_total_positives_list), 768), float('nan'), device=DEVICE)
    batch_negative_embeddings = torch.full((len(batch_query_index_list), max(batch_total_negatives_list), 768), float('nan'), device=DEVICE)
    for i, (query_index, total_positives, total_negatives) in enumerate(zip(batch_query_index_list, batch_total_positives_list, batch_total_negatives_list)):
        positive_index_list = get_positive_indices(dataset, query_index, baseline_retrieved_passage_indices[query_index, :], total_positives, mode)
        negative_index_list = get_negative_indices(dataset, query_index, baseline_retrieved_passage_indices[query_index, :], total_negatives, mode)
        batch_positive_embeddings[i, :len(positive_index_list), :] = torch.from_numpy(passage_embeddings[positive_index_list, :]).to(device=DEVICE)
        batch_negative_embeddings[i, :len(negative_index_list), :] = torch.from_numpy(passage_embeddings[negative_index_list, :]).to(device=DEVICE)
    return batch_positive_embeddings, batch_negative_embeddings

def get_loss(batch_mapped_query_embeddings:torch.Tensor, batch_positive_embeddings:torch.Tensor, batch_negative_embeddings:torch.Tensor, margin:float, norm_order:int) -> torch.Tensor:
    batch_aggregated_positive_embeddings = torch.nanmean(batch_positive_embeddings, dim=1)
    batch_aggregated_negative_embeddings = torch.nanmean(batch_negative_embeddings, dim=1)
    batch_positive_scores = torch.norm(batch_mapped_query_embeddings - batch_aggregated_positive_embeddings, p=norm_order, dim=1)
    batch_negative_scores = torch.norm(batch_mapped_query_embeddings - batch_aggregated_negative_embeddings, p=norm_order, dim=1)
    positive_loss = torch.nanmean(torch.abs(batch_positive_scores)**2)
    negative_loss = torch.nanmean(torch.relu(margin - batch_negative_scores)**2)
    loss = None
    if not torch.isnan(positive_loss) and not torch.isnan(negative_loss):
        loss = positive_loss + negative_loss
    elif not torch.isnan(positive_loss):
        loss = positive_loss
    elif not torch.isnan(negative_loss):
        loss = negative_loss
    return loss

## Development

In [None]:
total_epochs = 50
patience = 3

batch_size = 512
learning_rate = 0.0006598
preferred_total = 2
positive_tendency = 0.75
mode = 'worst-worst'
margin = 0.2068
norm_order = 3

mapper = Mapper().to(device=DEVICE)
optimizer = torch.optim.Adam(mapper.parameters(), lr=learning_rate)

best_mapper = mapper
best_validation_avg_pessimistic_mrr = -float('inf')
total_epochs_since_improvement = 0
for epoch in range(total_epochs):

    with tqdm(total=len(ms_marco_dataset.train_set) // batch_size, desc=f'Epoch {epoch + 1:02}/{total_epochs}') as pbar:

        mapper.train()
        loss_list = list[float]()
        for step in range(len(ms_marco_dataset.train_set) // batch_size):

            batch_query_index_list = random.sample(list(ms_marco_dataset.train_set), batch_size)
            batch_query_embeddings = torch.from_numpy(ms_marco_b3_query_embeddings[batch_query_index_list, :]).to(device=DEVICE)
            batch_mapped_query_embeddings = mapper(batch_query_embeddings)

            batch_positive_embeddings, batch_negative_embeddings = get_targets(ms_marco_dataset, ms_marco_b3_passage_embeddings, ms_marco_b3_retrieved_passage_indices, batch_query_index_list, preferred_total, positive_tendency, mode)
            loss = get_loss(batch_mapped_query_embeddings, batch_positive_embeddings, batch_negative_embeddings, margin, norm_order)
            if loss is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())

            pbar.set_postfix_str(f'loss: {np.mean(loss_list):.4f}', refresh=False)
            pbar.update(1)

        avg_loss = np.mean(loss_list)

        mapper.eval()
        with torch.no_grad():

            sample_train_query_index_list = random.sample(list(ms_marco_dataset.train_set), batch_size)
            sample_train_query_embeddings = torch.from_numpy(ms_marco_b3_query_embeddings[sample_train_query_index_list, :]).to(device=DEVICE)
            sample_train_mapped_query_embeddings = mapper(sample_train_query_embeddings)
            sample_train_retrieved_passage_indices = ms_marco_semantic_searcher.search(sample_train_mapped_query_embeddings.cpu().numpy())
            _, _, avg_train_optimistic_mrr, avg_train_pessimistic_mrr = ms_marco_dataset.get_metrics(sample_train_query_index_list, sample_train_retrieved_passage_indices)

            validation_query_index_list = list(ms_marco_dataset.validation_set)
            validation_query_embeddings = torch.from_numpy(ms_marco_b3_query_embeddings[validation_query_index_list, :]).to(device=DEVICE)
            validation_mapped_query_embeddings = mapper(validation_query_embeddings)
            validation_retrieved_passage_indices = ms_marco_semantic_searcher.search(validation_mapped_query_embeddings.cpu().numpy())
            _, _, avg_validation_optimistic_mrr, avg_validation_pessimistic_mrr = ms_marco_dataset.get_metrics(validation_query_index_list, validation_retrieved_passage_indices)

            pbar.set_postfix_str(f'train (loss: {avg_loss:.4f}, o-mrr: {avg_train_optimistic_mrr:.4f}, p-mrr: {avg_train_pessimistic_mrr:.4f}), validation (o-mrr: {avg_validation_optimistic_mrr:.4f}, p-mrr: {avg_validation_pessimistic_mrr:.4f})', refresh=True)

    if avg_validation_pessimistic_mrr > best_validation_avg_pessimistic_mrr:
        best_mapper = copy.deepcopy(mapper)
        best_validation_avg_pessimistic_mrr = avg_validation_pessimistic_mrr
        total_epochs_since_improvement = 0
    else:
        total_epochs_since_improvement += 1
    if total_epochs_since_improvement >= patience:
        break

Epoch 01/50: 100%|██████████| 1/1 [00:00<00:00,  1.17it/s, train (loss: 0.1150, o-mrr: 0.9860, p-mrr: 0.7948), validation (o-mrr: 1.0000, p-mrr: 0.8023)]
Epoch 02/50: 100%|██████████| 1/1 [00:00<00:00,  1.78it/s, train (loss: 0.1064, o-mrr: 0.9872, p-mrr: 0.8114), validation (o-mrr: 1.0000, p-mrr: 0.8083)]
Epoch 03/50: 100%|██████████| 1/1 [00:00<00:00,  1.95it/s, train (loss: 0.1000, o-mrr: 0.9873, p-mrr: 0.8287), validation (o-mrr: 1.0000, p-mrr: 0.8142)]
Epoch 04/50: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s, train (loss: 0.0946, o-mrr: 0.9875, p-mrr: 0.8391), validation (o-mrr: 1.0000, p-mrr: 0.8157)]
Epoch 05/50: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s, train (loss: 0.0903, o-mrr: 0.9882, p-mrr: 0.8514), validation (o-mrr: 1.0000, p-mrr: 0.8147)]
Epoch 06/50: 100%|██████████| 1/1 [00:00<00:00,  1.96it/s, train (loss: 0.0864, o-mrr: 0.9891, p-mrr: 0.8542), validation (o-mrr: 1.0000, p-mrr: 0.8146)]
Epoch 07/50: 100%|██████████| 1/1 [00:00<00:00,  1.99it/s, train (loss: 0.08


test (o-mrr: 0.9883 | p-mrr: 0.8359)





## Evaluation

In [None]:
mapper.eval()
with torch.no_grad():
    test_query_index_list = list(ms_marco_dataset.test_set)
    test_query_embeddings = torch.from_numpy(ms_marco_b3_query_embeddings[test_query_index_list, :]).to(device=DEVICE)
    test_mapped_query_embeddings = mapper(test_query_embeddings)
    test_retrieved_passage_indices = ms_marco_semantic_searcher.search(test_mapped_query_embeddings.cpu().numpy())
    print('Proposed Method: Semantic Search (MS-Marco, Test)')
    ms_marco_dataset.print_metrics(test_query_index_list, test_retrieved_passage_indices)

In [None]:
mapper.eval()
with torch.no_grad():
    test_query_index_list = list(hotpot_qa_dataset.test_set)
    test_query_embeddings = torch.from_numpy(hotpot_qa_b3_query_embeddings[test_query_index_list, :]).to(device=DEVICE)
    test_mapped_query_embeddings = mapper(test_query_embeddings)
    test_retrieved_passage_indices = hotpot_qa_semantic_searcher.search(test_mapped_query_embeddings.cpu().numpy())
    print('Proposed Method: Semantic Search (Hotpot-QA, Test)')
    hotpot_qa_dataset.print_metrics(test_query_index_list, test_retrieved_passage_indices)