# Notebook Initialization

In [None]:
# Runtime dependencies for Colab.
# NOTE(review): versions are unpinned, so a later re-run may install different releases — consider pinning.
!pip install datasets
!pip install sentence_transformers
# NOTE(review): faiss-cpu and faiss-gpu both provide the `faiss` module; installing both likely
# conflicts. Keep only the one matching the runtime — TODO confirm which is intended.
!pip install faiss-cpu
!pip install faiss-gpu
!pip install scann
!pip install fastapi
!pip install python-multipart
!pip install pyngrok
!pip install uvicorn
!pip install groq

In [None]:
from google.colab import drive
from typing import Any, Callable, Iterable
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from pyngrok import ngrok
from fastapi.middleware.cors import CORSMiddleware
from groq import Groq
import os
import time
import re
import json
import copy
import pickle
import enum
import torch
import faiss
import scann
import random
import numpy as np
import pandas as pd
import fastapi
import pyngrok
import asyncio
import uvicorn
import nest_asyncio

In [None]:
drive.mount('/content/drive')

DATASET_ROOT = '/content/drive/MyDrive/ADSP Project/datasets/'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Number of passages SemanticSearcher.search returns per query.
RETRIEVAL_CAPACITY = 50
GROQ_MODEL = 'llama3-8b-8192'
# Secrets are read from the environment so they never need to be typed into the notebook;
# the placeholder strings remain as fallbacks.
GROQ_API_KEY = os.environ.get('GROQ_API_KEY', 'YOUR API KEY FROM GROQ')
NGROK_AUTHENTICATION_TOKEN = os.environ.get('NGROK_AUTHENTICATION_TOKEN', 'YOUR AUTHENTICATION TOKEN FROM NGROK')
INTERNAL_PORT = 8002
SEED = 42

# Seed every RNG the notebook relies on so batch sampling and training are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if not os.path.exists(DATASET_ROOT):
    raise ValueError('Invalid data root')

# Classes

## Config

In [None]:
class Config:
    """Namespace of enums naming the selectable datasets, embedding models, vector DBs and similarity metrics."""

    class DATASET_NAMES(enum.Enum):
        """Dataset identifiers; a loaded pickle's ``dataset_name`` must equal one of these values."""
        MS_MARCO = 'ms-marco'
        HOTPOT_QA = 'hotpot-qa'

    class TRANSFORMER_MODEL_NAMES(enum.Enum):
        """Embedding models; each value is passed to SentenceTransformer as ``model_name_or_path``."""
        ALL_MPNET_BASE_V2 = 'all-mpnet-base-v2'
        MULTI_QA_MPNET_BASE_DOT_V1 = 'multi-qa-mpnet-base-dot-v1'
        ALL_DISTILROBERTA_V1 = 'all-distilroberta-v1'

    class VECTOR_DB_NAMES(enum.Enum):
        """Search backends supported by SemanticSearcher."""
        FAISS = 'faiss'
        SCANN = 'scann'

    class SIMILARITY_METRIC_NAMES(enum.Enum):
        """Similarity metrics: L2 distance, inner product, and cosine similarity (inner product over normalised vectors)."""
        L2 = 'l2'
        IP = 'ip'
        CS = 'cs'

## Dataset

In [None]:
class Dataset:
    """Retrieval dataset loaded from a pickle, plus recall / MRR evaluation helpers.

    Attributes that the pickle may overwrite (defaults are set in ``__init__``):
        passage_list / query_list: passage and query texts, addressed by list index.
        relation_list: for each query index, the set of related passage indices.
        augmentation_dict: augmentation name -> set of query indices augmented with it.
        train_set / validation_set / test_set: query indices of each split.
    """

    def __init__(self, file_name:str) -> None:
        """Load ``<file_name>.pickle`` from ``DATASET_ROOT`` and compute summary statistics.

        The pickle is treated as a mapping of attribute name -> value; every entry is set on
        the instance.

        Raises:
            ValueError: if the pickle does not exist, or its ``dataset_name`` is not a value
                of ``Config.DATASET_NAMES``.
        """
        self._file_name:str = file_name
        self._stat_dict = {
            'passages': dict[str, int](),
            'queries': dict[str, int](),
            'augmentations': dict[str, int](),
            'relations': dict[str, int](),
            'learning': dict[str, int]()
        }
        self.dataset_name:Config.DATASET_NAMES = None
        self.passage_list = list[str]()
        self.query_list = list[str]()
        self.passage_augmentation_list = list[dict[str, dict[str, int]]]()
        self.query_augmentation_list = list[dict[str, dict[str, int]]]()
        self.augmentation_dict = dict[str, set[int]]()
        self.relation_list = list[set[int]]()
        self.train_set = set[int]()
        self.validation_set = set[int]()
        self.test_set = set[int]()
        potential_dataset_path = os.path.join(DATASET_ROOT, f'{file_name}.pickle')
        if not os.path.exists(potential_dataset_path):
            raise ValueError('Invalid file name')
        # NOTE(security): pickle.load can execute arbitrary code — only load trusted dataset files.
        with open(potential_dataset_path, 'rb') as file_handle:
            public_dataset = pickle.load(file_handle)
        for attribute in public_dataset:
            setattr(self, attribute, public_dataset[attribute])
        if self.dataset_name not in {item.value for item in Config.DATASET_NAMES}:
            raise ValueError('Invalid dataset name')
        self.dataset_name = Config.DATASET_NAMES(self.dataset_name)
        self._update_stat()

    def __str__(self) -> str:
        """Multi-line summary: file / dataset names, then one line per non-empty stat group."""
        output_list = [f'names -> file: {self._file_name}, dataset: {self.dataset_name}']
        for stat in self._stat_dict:
            if len(self._stat_dict[stat]) == 0:
                continue
            output_list.append(f'{stat} -> ' + ', '.join(f'{attribute}: {self._stat_dict[stat][attribute]}' for attribute in self._stat_dict[stat]))
        return '\n'.join(output_list)

    def _update_stat(self) -> None:
        """Recompute the summary statistics reported by ``__str__``."""
        def _count_quantity(key:str, suffix:str, target:Any) -> None:
            # An empty suffix yields a plain 'total' key rather than a dangling 'total_'.
            self._stat_dict[key][f'total_{suffix}' if suffix else 'total'] = len(target)
        def _compute_stat(key:str, suffix:str, target_list:list[Iterable]) -> None:
            if len(target_list) > 0:
                lengths = [len(iterable) for iterable in target_list]
                self._stat_dict[key][f'minimum_{suffix}'] = min(lengths)
                self._stat_dict[key][f'average_{suffix}'] = round(sum(lengths) / len(lengths))
                self._stat_dict[key][f'maximum_{suffix}'] = max(lengths)
        _count_quantity('passages', '', self.passage_list)
        _compute_stat('passages', 'length', self.passage_list)
        _count_quantity('queries', '', self.query_list)
        _compute_stat('queries', 'length', self.query_list)
        for augmentation_name in self.augmentation_dict:
            _count_quantity('augmentations', f'queries_augmented_with_{augmentation_name}', self.augmentation_dict[augmentation_name])
        _compute_stat('relations', 'related_passages', self.relation_list)
        _count_quantity('learning', 'queries_in_train_set', self.train_set)
        _count_quantity('learning', 'queries_in_validation_set', self.validation_set)
        _count_quantity('learning', 'queries_in_test_set', self.test_set)

    def _get_recall(self, query_index:int, query_retrieved_passage_indices:np.ndarray) -> list[float]:
        """Recall@k for k = 1..len(ranking): share of the query's related passages found in the top k.

        The caller must ensure the query has at least one related passage.
        """
        related_passage_indices = self.relation_list[query_index]
        total_related_passages = len(related_passage_indices)
        found_passage_indices = set()
        recall_list = list[float]()
        # Track hits incrementally instead of re-intersecting every prefix (O(k) instead of O(k^2)).
        for passage_index in query_retrieved_passage_indices:
            if passage_index in related_passage_indices:
                found_passage_indices.add(passage_index)
            recall_list.append(len(found_passage_indices) / total_related_passages)
        return recall_list

    def _get_optimistic_mrr(self, query_index:int, query_retrieved_passage_indices:np.ndarray) -> float:
        """Reciprocal rank of the FIRST related passage in the ranking (0.0 if none is retrieved)."""
        optimistic_mrr = 0.0
        for rank, retrieved_passage_index in enumerate(query_retrieved_passage_indices, start=1):
            if retrieved_passage_index in self.relation_list[query_index]:
                optimistic_mrr = 1.0 / rank
                break
        return optimistic_mrr

    def _get_pessimistic_mrr(self, query_index:int, query_retrieved_passage_indices:np.ndarray) -> float:
        """R / r, where R is the number of related passages and r the rank at which the LAST of them appears.

        Returns 1.0 for a perfect ranking, and 0.0 if not all related passages are retrieved
        or the query has no related passages.
        """
        related_passage_indices = self.relation_list[query_index]
        total_related_passages = len(related_passage_indices)
        if total_related_passages == 0:
            return 0.0
        found_passage_indices = set()
        for rank, passage_index in enumerate(query_retrieved_passage_indices, start=1):
            if passage_index in related_passage_indices:
                found_passage_indices.add(passage_index)
                if len(found_passage_indices) == total_related_passages:
                    return total_related_passages / rank
        return 0.0

    @staticmethod
    def _mean(values:list[float]) -> float:
        """Arithmetic mean, or NaN when ``values`` is empty (no evaluable queries)."""
        return sum(values) / len(values) if values else float('nan')

    def get_metrics(self, query_index_list:list[int], retrieved_passage_indices:np.ndarray) -> tuple[dict[int, float], dict[int, float], float, float]:
        """Average retrieval metrics over the given queries.

        Args:
            query_index_list: query indices; row ``i`` of ``retrieved_passage_indices`` belongs
                to ``query_index_list[i]``.
            retrieved_passage_indices: 2-D array of ranked passage indices, one row per query.

        Returns:
            ``(recall@k for every k, recall@R grouped by R = number of related passages,
            mean optimistic MRR, mean pessimistic MRR)``. Queries without related passages
            are skipped; if none remain, the dicts are empty and both MRR means are NaN.
        """
        recall_dict = dict[int, list[float]]()
        recall_star_dict = dict[int, list[float]]()
        optimistic_mrr_list = list[float]()
        pessimistic_mrr_list = list[float]()
        for row, query_index in enumerate(query_index_list):
            total_related_passages = len(self.relation_list[query_index])
            if total_related_passages == 0:
                continue
            query_retrieved_passage_indices = retrieved_passage_indices[row, :]
            recall_list = self._get_recall(query_index, query_retrieved_passage_indices)
            for k, recall in enumerate(recall_list, start=1):
                recall_dict.setdefault(k, list[float]()).append(recall)
                if k == total_related_passages:
                    recall_star_dict.setdefault(total_related_passages, list[float]()).append(recall)
            optimistic_mrr_list.append(self._get_optimistic_mrr(query_index, query_retrieved_passage_indices))
            pessimistic_mrr_list.append(self._get_pessimistic_mrr(query_index, query_retrieved_passage_indices))
        avg_recall_dict = {k: self._mean(recall_list) for k, recall_list in sorted(recall_dict.items())}
        avg_recall_star_dict = {total_related_passages: self._mean(recall_star_list) for total_related_passages, recall_star_list in sorted(recall_star_dict.items())}
        avg_optimistic_mrr = self._mean(optimistic_mrr_list)
        avg_pessimistic_mrr = self._mean(pessimistic_mrr_list)
        return avg_recall_dict, avg_recall_star_dict, avg_optimistic_mrr, avg_pessimistic_mrr

    def print_metrics(self, query_index_list:list[int], retrieved_passage_indices:np.ndarray) -> None:
        """Print the metrics from ``get_metrics`` as percentages."""
        avg_recall_dict, avg_recall_star_dict, avg_optimistic_mrr, avg_pessimistic_mrr = self.get_metrics(query_index_list, retrieved_passage_indices)
        print('Dataset Name ->', self.dataset_name)
        print('Recall ->', ' | '.join(f'{k}: {100 * avg_recall:.2f}%' for k, avg_recall in avg_recall_dict.items()))
        print('Cluster Recall ->', ' | '.join(f'{total_related_passages}: {100 * avg_recall:.2f}%' for total_related_passages, avg_recall in avg_recall_star_dict.items()))
        print('MRR ->', f'optimistic: {100 * avg_optimistic_mrr:.2f}% | pessimistic: {100 * avg_pessimistic_mrr:.2f}%')

## Transformer

In [None]:
class Transformer:
    """Wrapper around a sentence-transformers model used to embed passages and queries."""

    def __init__(self, model_name:Config.TRANSFORMER_MODEL_NAMES) -> None:
        """Load ``model_name`` on DEVICE, then run one throwaway encode (presumably a warm-up)."""
        self.model_name = model_name
        encoder = SentenceTransformer(model_name_or_path=self.model_name.value, device=DEVICE)
        self._transformer = encoder
        encoder.encode(['warm_up'])

    def embed(self, text_list:list[str]) -> np.ndarray:
        """Encode every text in ``text_list`` to a NumPy array, printing a short progress message."""
        print('Embedding ...', end='')
        vectors = self._transformer.encode(text_list, convert_to_numpy=True)
        print(' done')
        return vectors

## Semantic Searcher

In [None]:
class SemanticSearcher:
    """Exact top-``RETRIEVAL_CAPACITY`` passage search backed by FAISS or ScaNN.

    Cosine similarity (CS) is implemented on both backends as an inner product over
    row-wise L2-normalised embeddings.
    """

    def __init__(self, vectordb_name:Config.VECTOR_DB_NAMES, similarity_metric_name:Config.SIMILARITY_METRIC_NAMES, dimension:int=768) -> None:
        """Create an empty searcher.

        Args:
            vectordb_name: backend (FAISS or SCANN).
            similarity_metric_name: L2, IP (inner product) or CS (cosine similarity).
            dimension: embedding width used to size the FAISS index; must match the embeddings
                later passed to ``index`` / ``search``. Not used by ScaNN, whose searcher is
                built from the embeddings given to ``index``.

        Raises:
            ValueError: if the backend or the metric is not supported.
        """
        self.vectordb_name = vectordb_name
        self.similarity_metric_name = similarity_metric_name
        self._engine:Any = None
        supported_metrics = (Config.SIMILARITY_METRIC_NAMES.L2, Config.SIMILARITY_METRIC_NAMES.IP, Config.SIMILARITY_METRIC_NAMES.CS)
        if similarity_metric_name not in supported_metrics:
            raise ValueError(f'Unsupported similarity metric: {similarity_metric_name}')
        if vectordb_name == Config.VECTOR_DB_NAMES.FAISS:
            if similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.L2:
                self._engine = faiss.IndexFlatL2(dimension)
            else:
                # IP and CS both rank by inner product; CS normalises the vectors first (see _prepare).
                self._engine = faiss.IndexFlatIP(dimension)
        elif vectordb_name != Config.VECTOR_DB_NAMES.SCANN:
            raise ValueError(f'Unsupported vector database: {vectordb_name}')

    def _prepare(self, embeddings:np.ndarray) -> np.ndarray:
        """Return embeddings as the engine expects them: row-L2-normalised for CS, unchanged otherwise."""
        if self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.CS:
            return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
        return embeddings

    def index(self, passage_embeddings:np.ndarray) -> None:
        """Add passages to the FAISS index, or (re)build the brute-force ScaNN searcher over them.

        FAISS indices accumulate across calls, whereas the ScaNN searcher is rebuilt each call.
        """
        prepared_passages = self._prepare(passage_embeddings)
        if self.vectordb_name == Config.VECTOR_DB_NAMES.FAISS:
            self._engine.add(prepared_passages)
        else:
            distance_measure = 'squared_l2' if self.similarity_metric_name == Config.SIMILARITY_METRIC_NAMES.L2 else 'dot_product'
            self._engine = scann.scann_ops_pybind.builder(prepared_passages, RETRIEVAL_CAPACITY, distance_measure).score_brute_force().build()

    def search(self, query_embeddings:np.ndarray) -> np.ndarray:
        """Return the top ``RETRIEVAL_CAPACITY`` passage indices for every query.

        Args:
            query_embeddings: 2-D array with one query embedding per row.

        Returns:
            2-D integer array with one row of retrieved passage indices per query.

        Raises:
            RuntimeError: if a ScaNN searcher is queried before ``index`` was called.
        """
        prepared_queries = self._prepare(query_embeddings)
        if self.vectordb_name == Config.VECTOR_DB_NAMES.FAISS:
            # FAISS returns (distances, indices); only the indices are needed.
            return self._engine.search(prepared_queries, RETRIEVAL_CAPACITY)[1]
        if self._engine is None:
            raise RuntimeError('index() must be called before search() when using ScaNN')
        # ScaNN's single-query search returns (neighbors, distances).
        return np.array([self._engine.search(prepared_queries[row, :])[0] for row in range(prepared_queries.shape[0])])

# Web Server

## Initialization

In [None]:
# Baseline retrieval setup: embed every MS MARCO passage and query once, index the passages,
# and keep each query's baseline ranking (used later by get_targets to pick training targets).
ms_marco_transformer = Transformer(Config.TRANSFORMER_MODEL_NAMES.ALL_MPNET_BASE_V2)
ms_marco_semantic_searcher = SemanticSearcher(
    Config.VECTOR_DB_NAMES.FAISS,
    Config.SIMILARITY_METRIC_NAMES.CS,
)

ms_marco_dataset = Dataset('ms-marco-no-augmentation')
print(f'\n{ms_marco_dataset}')

ms_marco_b3_passage_embeddings, ms_marco_b3_query_embeddings = (
    ms_marco_transformer.embed(text_list)
    for text_list in (ms_marco_dataset.passage_list, ms_marco_dataset.query_list)
)
ms_marco_semantic_searcher.index(ms_marco_b3_passage_embeddings)
ms_marco_b3_retrieved_passage_indices = ms_marco_semantic_searcher.search(ms_marco_b3_query_embeddings)

## Classes and Functions

In [None]:
class Mapper(torch.nn.Module):
    """Learned affine map applied to query embeddings before retrieval.

    It is initialised to the identity, so an untrained Mapper leaves embeddings unchanged.
    """

    def __init__(self, dimension:int=768) -> None:
        """Args:
            dimension: embedding width (input and output size of the linear map).
        """
        super().__init__()
        self.linear = torch.nn.Linear(dimension, dimension)
        self.reset()

    def reset(self) -> None:
        """Reset to the identity map (W = I, b = 0).

        This works in place, so the parameters keep their current device and dtype and the
        weight and bias can never end up on different devices.
        """
        with torch.no_grad():
            torch.nn.init.eye_(self.linear.weight)
            torch.nn.init.zeros_(self.linear.bias)

    def forward(self, batch_query_embeddings:torch.Tensor) -> torch.Tensor:
        """Apply the affine map to a batch of query embeddings (last dim = ``dimension``)."""
        batch_mapped_query_embeddings = self.linear(batch_query_embeddings)
        return batch_mapped_query_embeddings

def get_positive_indices(dataset:Dataset, query_index:int, query_baseline_retrieved_passage_indices:np.array, total:int, mode:str) -> list[int]:
    """Choose up to ``total`` passages related to the query, to serve as positive targets.

    With ``'random'``, this samples without replacement from the related set (``random.sample``
    raises ValueError if the set holds fewer than ``total``). Any other mode scans the baseline
    ranking for related passages: from the bottom for ``'worst-worst'`` / ``'worst-best'``, and
    from the top otherwise. In that case fewer than ``total`` may be returned.
    """
    related_passage_indices = dataset.relation_list[query_index]
    if mode == 'random':
        return random.sample(list(related_passage_indices), total)
    scan_from_bottom = mode in ('worst-worst', 'worst-best')
    ranking = np.flipud(query_baseline_retrieved_passage_indices) if scan_from_bottom else query_baseline_retrieved_passage_indices
    chosen = list[int]()
    for candidate in ranking:
        if len(chosen) == total:
            break
        if candidate in related_passage_indices:
            chosen.append(candidate)
    return chosen

def get_negative_indices(dataset:Dataset, query_index:int, query_baseline_retrieved_passage_indices:np.array, total:int, mode:str) -> list[int]:
    """Choose up to ``total`` passages NOT related to the query, to serve as negative targets.

    Modes:
        ``'random'``: draw passage indices uniformly (with replacement) from all passages,
            rejecting related ones. Returns ``[]`` if every passage is related.
        ``'worst-worst'`` / ``'best-worst'``: scan the baseline ranking from its bottom.
        anything else: scan the baseline ranking from its top.
    When scanning, fewer than ``total`` indices are returned if the ranking has too few
    unrelated passages.
    """
    related_passage_indices = dataset.relation_list[query_index]
    negative_index_list = list[int]()
    if mode == 'random':
        # Negatives must be PASSAGE indices; dataset.train_set holds query indices.
        total_passages = len(dataset.passage_list)
        total_related_in_range = sum(1 for passage_index in related_passage_indices if 0 <= passage_index < total_passages)
        # With no unrelated passage at all, the rejection loop below could never terminate.
        if total_related_in_range >= total_passages:
            return negative_index_list
        while len(negative_index_list) < total:
            negative_index = random.randrange(total_passages)
            if negative_index not in related_passage_indices:
                negative_index_list.append(negative_index)
        return negative_index_list
    if mode == 'worst-worst' or mode == 'best-worst':
        query_baseline_retrieved_passage_indices = np.flipud(query_baseline_retrieved_passage_indices)
    for passage_index in query_baseline_retrieved_passage_indices:
        if len(negative_index_list) == total:
            break
        if passage_index not in related_passage_indices:
            negative_index_list.append(passage_index)
    return negative_index_list

def get_targets(dataset:Dataset, passage_embeddings:np.array, baseline_retrieved_passage_indices:np.array, batch_query_index_list:list[int], preferred_total:int, positive_tendency:float, mode:str) -> tuple[torch.Tensor, torch.Tensor]:
    """Build NaN-padded positive / negative target embeddings for a batch of queries.

    For each query, ``preferred_total`` targets are split into positives and negatives by
    ``positive_tendency``. If the query has fewer related passages than the requested
    positives, both counts are scaled down by the same factor, which keeps their ratio.
    A query with no related passages gets no positives (its negative count is not scaled).

    Returns:
        ``(positives, negatives)``, each shaped (batch, max count in batch, embedding width) on
        DEVICE. Rows with fewer targets are padded with NaN, so callers aggregate with
        nan-aware reductions.
    """
    embedding_dimension = passage_embeddings.shape[1]
    batch_total_positives_list = list[int]()
    batch_total_negatives_list = list[int]()
    for query_index in batch_query_index_list:
        total_related_passages = len(dataset.relation_list[query_index])
        total_positives = preferred_total * positive_tendency
        total_negatives = preferred_total - total_positives
        if total_related_passages == 0:
            # Nothing to draw positives from, and the scaling below would divide by zero.
            batch_total_positives_list.append(0)
            batch_total_negatives_list.append(round(total_negatives))
            continue
        total_positives_error = max(1.0, total_positives / total_related_passages)
        batch_total_positives_list.append(round(total_positives / total_positives_error))
        batch_total_negatives_list.append(round(total_negatives / total_positives_error))
    batch_positive_embeddings = torch.full((len(batch_query_index_list), max(batch_total_positives_list), embedding_dimension), float('nan'), device=DEVICE)
    batch_negative_embeddings = torch.full((len(batch_query_index_list), max(batch_total_negatives_list), embedding_dimension), float('nan'), device=DEVICE)
    for i, (query_index, total_positives, total_negatives) in enumerate(zip(batch_query_index_list, batch_total_positives_list, batch_total_negatives_list)):
        positive_index_list = get_positive_indices(dataset, query_index, baseline_retrieved_passage_indices[query_index, :], total_positives, mode)
        negative_index_list = get_negative_indices(dataset, query_index, baseline_retrieved_passage_indices[query_index, :], total_negatives, mode)
        batch_positive_embeddings[i, :len(positive_index_list), :] = torch.from_numpy(passage_embeddings[positive_index_list, :]).to(device=DEVICE)
        batch_negative_embeddings[i, :len(negative_index_list), :] = torch.from_numpy(passage_embeddings[negative_index_list, :]).to(device=DEVICE)
    return batch_positive_embeddings, batch_negative_embeddings

def get_loss(batch_mapped_query_embeddings:torch.Tensor, batch_positive_embeddings:torch.Tensor, batch_negative_embeddings:torch.Tensor, margin:float, norm_order:int) -> torch.Tensor:
    """Contrastive loss that pulls mapped queries toward positive centroids and pushes them past ``margin`` from negative ones.

    Targets are NaN-padded along dim 1. Each query's targets are averaged with ``nanmean``,
    and the query-to-centroid distance uses the ``norm_order``-norm.
    positive term = mean(d_pos ** 2); negative term = mean(relu(margin - d_neg) ** 2).
    A term that is NaN (no targets of that kind in the batch) is dropped.

    Returns:
        The summed available terms, or None when both terms are NaN.
    """
    def _distance_to_centroid(batch_targets:torch.Tensor) -> torch.Tensor:
        centroids = torch.nanmean(batch_targets, dim=1)
        return torch.norm(batch_mapped_query_embeddings - centroids, p=norm_order, dim=1)

    positive_distances = _distance_to_centroid(batch_positive_embeddings)
    negative_distances = _distance_to_centroid(batch_negative_embeddings)
    pull_term = torch.nanmean(torch.abs(positive_distances)**2)
    push_term = torch.nanmean(torch.relu(margin - negative_distances)**2)

    has_pull = not torch.isnan(pull_term)
    has_push = not torch.isnan(push_term)
    if has_pull and has_push:
        return pull_term + push_term
    if has_pull:
        return pull_term
    if has_push:
        return push_term
    return None

def get_llm_response(query:str, passage_list:list[str]):
    """Ask the Groq-hosted LLM to answer ``query`` using only ``passage_list`` as context.

    Args:
        query: the user question.
        passage_list: passages sent to the model, one per line, as the document list.

    Returns:
        The message content of the first completion choice.
    """
    groq_client = Groq(api_key=GROQ_API_KEY)
    system_prompt = (
        'You are an AI assistant tasked to work in a Retrieval Augmented Generation architecture.'
        '\nYou receive a question and a list of documents.'
        '\nYou answer the question only based on the provided context.'
        '\nNever use your general knowledge in your response.'
    )
    # Each section starts on its own line so the question does not run into the
    # document-list header.
    user_prompt = (
        'The following is the question:'
        '\n' + query +
        '\nAnd this is the document list:'
        '\n' + '\n'.join(passage_list)
    )
    chat_completion = groq_client.chat.completions.create(
        messages=[
            {'role': 'system', 'content': system_prompt},
            {'role': 'user', 'content': user_prompt},
        ],
        model=GROQ_MODEL,
        max_tokens=512,
    )
    return chat_completion.choices[0].message.content

## Development

In [None]:
total_epochs = 50
patience = 3

batch_size = 512
learning_rate = 0.0006598
preferred_total = 2
positive_tendency = 0.75
mode = 'worst-worst'
margin = 0.2068
norm_order = 3

mapper = Mapper().to(device=DEVICE)
optimizer = torch.optim.Adam(mapper.parameters(), lr=learning_rate)

def _mean_or_nan(values:list[float]) -> float:
    """Mean of ``values``, or NaN (without numpy's empty-slice warning) when there are none."""
    return float(np.mean(values)) if values else float('nan')

# Built once instead of on every step / epoch; the splits are not modified during training.
train_query_index_list = list(ms_marco_dataset.train_set)
validation_query_index_list = list(ms_marco_dataset.validation_set)
validation_query_embeddings = torch.from_numpy(ms_marco_b3_query_embeddings[validation_query_index_list, :]).to(device=DEVICE)
steps_per_epoch = len(train_query_index_list) // batch_size

# Start from a snapshot, not an alias, so best_mapper never follows later updates of `mapper`
# even if no epoch improves on the initial score.
best_mapper = copy.deepcopy(mapper)
best_validation_avg_pessimistic_mrr = -float('inf')
total_epochs_since_improvement = 0
for epoch in range(total_epochs):

    with tqdm(total=steps_per_epoch, desc=f'Epoch {epoch + 1:02}/{total_epochs}') as pbar:

        mapper.train()
        loss_list = list[float]()
        for _ in range(steps_per_epoch):

            # Each step draws an independent random batch (this is not a shuffled pass over the train set).
            batch_query_index_list = random.sample(train_query_index_list, batch_size)
            batch_query_embeddings = torch.from_numpy(ms_marco_b3_query_embeddings[batch_query_index_list, :]).to(device=DEVICE)
            batch_mapped_query_embeddings = mapper(batch_query_embeddings)

            batch_positive_embeddings, batch_negative_embeddings = get_targets(ms_marco_dataset, ms_marco_b3_passage_embeddings, ms_marco_b3_retrieved_passage_indices, batch_query_index_list, preferred_total, positive_tendency, mode)
            loss = get_loss(batch_mapped_query_embeddings, batch_positive_embeddings, batch_negative_embeddings, margin, norm_order)
            # get_loss returns None when the batch has neither usable positives nor negatives.
            if loss is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())

            pbar.set_postfix_str(f'loss: {_mean_or_nan(loss_list):.4f}', refresh=False)
            pbar.update(1)

        avg_loss = _mean_or_nan(loss_list)

        mapper.eval()
        with torch.no_grad():

            sample_train_query_index_list = random.sample(train_query_index_list, batch_size)
            sample_train_query_embeddings = torch.from_numpy(ms_marco_b3_query_embeddings[sample_train_query_index_list, :]).to(device=DEVICE)
            sample_train_mapped_query_embeddings = mapper(sample_train_query_embeddings)
            sample_train_retrieved_passage_indices = ms_marco_semantic_searcher.search(sample_train_mapped_query_embeddings.cpu().numpy())
            _, _, avg_train_optimistic_mrr, avg_train_pessimistic_mrr = ms_marco_dataset.get_metrics(sample_train_query_index_list, sample_train_retrieved_passage_indices)

            validation_mapped_query_embeddings = mapper(validation_query_embeddings)
            validation_retrieved_passage_indices = ms_marco_semantic_searcher.search(validation_mapped_query_embeddings.cpu().numpy())
            _, _, avg_validation_optimistic_mrr, avg_validation_pessimistic_mrr = ms_marco_dataset.get_metrics(validation_query_index_list, validation_retrieved_passage_indices)

            pbar.set_postfix_str(f'train (loss: {avg_loss:.4f}, o-mrr: {avg_train_optimistic_mrr:.4f}, p-mrr: {avg_train_pessimistic_mrr:.4f}), validation (o-mrr: {avg_validation_optimistic_mrr:.4f}, p-mrr: {avg_validation_pessimistic_mrr:.4f})', refresh=True)

    # Early stopping on validation pessimistic MRR.
    if avg_validation_pessimistic_mrr > best_validation_avg_pessimistic_mrr:
        best_mapper = copy.deepcopy(mapper)
        best_validation_avg_pessimistic_mrr = avg_validation_pessimistic_mrr
        total_epochs_since_improvement = 0
    else:
        total_epochs_since_improvement += 1
    if total_epochs_since_improvement >= patience:
        break

Epoch 01/50: 100%|██████████| 1/1 [00:00<00:00,  3.90it/s, train (loss: 0.1115, o-mrr: 0.9898, p-mrr: 0.7792), validation (o-mrr: 1.0000, p-mrr: 0.7972)]
Epoch 02/50: 100%|██████████| 1/1 [00:00<00:00,  3.93it/s, train (loss: 0.1068, o-mrr: 0.9852, p-mrr: 0.7902), validation (o-mrr: 1.0000, p-mrr: 0.8029)]
Epoch 03/50: 100%|██████████| 1/1 [00:00<00:00,  3.84it/s, train (loss: 0.1025, o-mrr: 0.9861, p-mrr: 0.8226), validation (o-mrr: 1.0000, p-mrr: 0.8061)]
Epoch 04/50: 100%|██████████| 1/1 [00:00<00:00,  4.07it/s, train (loss: 0.0993, o-mrr: 0.9864, p-mrr: 0.8069), validation (o-mrr: 1.0000, p-mrr: 0.8093)]
Epoch 05/50: 100%|██████████| 1/1 [00:00<00:00,  4.17it/s, train (loss: 0.0950, o-mrr: 0.9864, p-mrr: 0.8221), validation (o-mrr: 1.0000, p-mrr: 0.8105)]
Epoch 06/50: 100%|██████████| 1/1 [00:00<00:00,  4.04it/s, train (loss: 0.0919, o-mrr: 0.9867, p-mrr: 0.8121), validation (o-mrr: 1.0000, p-mrr: 0.8106)]
Epoch 07/50: 100%|██████████| 1/1 [00:00<00:00,  3.83it/s, train (loss: 0.08

## Server

In [None]:
ngrok.set_auth_token(NGROK_AUTHENTICATION_TOKEN)
# Public tunnel to the local server; the URL is printed and substituted into index.html by serve_index.
public_url = ngrok.connect(INTERNAL_PORT).public_url
print(f'Public URL: {public_url}')

app = fastapi.FastAPI()
# Any origin may call the API. Credentials stay disabled: the app uses no cookies or auth, and
# a wildcard origin combined with allow_credentials=True makes Starlette echo the caller's
# Origin on credentialed requests, letting any site make them.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=False,
    allow_methods=['*'],
    allow_headers=['*'],
)

@app.get('/')
async def serve_index() -> fastapi.responses.HTMLResponse:
    """Serve index.html with every ``<PUBLIC_URL>`` placeholder replaced by the ngrok public URL."""
    with open('index.html', 'r') as html_file:
        template = html_file.read()
    rendered_page = template.replace('<PUBLIC_URL>', public_url)
    return fastapi.responses.HTMLResponse(content=rendered_page)

@app.get('/{filename}')
async def serve_static(filename: str) -> fastapi.responses.FileResponse:
    """Serve a static file from the current working directory.

    This app is publicly reachable through ngrok, so only regular, non-hidden files located
    directly in the working directory are served. Directories, dotfiles, '..' and symlinks that
    resolve elsewhere all get a 404.
    """
    static_root = os.path.realpath(os.getcwd())
    candidate_path = os.path.realpath(os.path.join(static_root, filename))
    if (
        os.path.dirname(candidate_path) != static_root
        or os.path.basename(candidate_path).startswith('.')
        or not os.path.isfile(candidate_path)
    ):
        return fastapi.responses.JSONResponse(content={'error': 'File not found'}, status_code=404)
    return fastapi.responses.FileResponse(candidate_path)

@app.post('/request')
async def receive_data(request: fastapi.Request) -> fastapi.responses.JSONResponse:
    """Handle the front end's two request kinds.

    - JSON body ``"SEND_QUERIES"``: return up to 10 random ``{query_index: query_text}`` pairs.
    - JSON body holding a query index: retrieve the top-5 passages for the query with the
      early-stopped ``best_mapper`` and answer from them ("good"), then answer again from 3-7
      random passages ("bad").

    Non-integer or out-of-range query indices get a 400 response.
    """
    data:Any = await request.json()
    if data == 'SEND_QUERIES':
        total_queries = len(ms_marco_dataset.query_list)
        random_query_index_list = random.sample(range(total_queries), min(10, total_queries))
        result = {query_index: ms_marco_dataset.query_list[query_index] for query_index in random_query_index_list}
        return fastapi.responses.JSONResponse(content=result)

    try:
        chosen_query_index = int(data)
    except (TypeError, ValueError, OverflowError):
        return fastapi.responses.JSONResponse(content={'error': 'Invalid query index'}, status_code=400)
    if not 0 <= chosen_query_index < len(ms_marco_dataset.query_list):
        return fastapi.responses.JSONResponse(content={'error': 'Query index out of range'}, status_code=400)

    chosen_query = ms_marco_dataset.query_list[chosen_query_index]
    chosen_query_embedding = torch.from_numpy(ms_marco_b3_query_embeddings[[chosen_query_index], :]).to(device=DEVICE)
    best_mapper.eval()
    with torch.no_grad():
        chosen_query_mapped_query_embedding = best_mapper(chosen_query_embedding)
    correct_passage_index_list = ms_marco_semantic_searcher.search(chosen_query_mapped_query_embedding.cpu().numpy())[0, :5].tolist()
    correct_passage_list = [ms_marco_dataset.passage_list[passage_index] for passage_index in correct_passage_index_list]
    total_passages = len(ms_marco_dataset.passage_list)
    wrong_passage_index_list = random.sample(range(total_passages), min(random.randint(3, 7), total_passages))
    wrong_passage_list = [ms_marco_dataset.passage_list[passage_index] for passage_index in wrong_passage_index_list]
    # The Groq client is synchronous; run both calls in worker threads so the event loop keeps serving requests.
    correct_answer, wrong_answer = await asyncio.gather(
        asyncio.to_thread(get_llm_response, chosen_query, correct_passage_list),
        asyncio.to_thread(get_llm_response, chosen_query, wrong_passage_list),
    )
    result = {
        'good': {'docs': correct_passage_list, 'answer': correct_answer},
        'bad': {'docs': wrong_passage_list, 'answer': wrong_answer}
    }
    return fastapi.responses.JSONResponse(content=result)

# Allow uvicorn to start its event loop from inside the notebook kernel, which presumably
# already runs one.
nest_asyncio.apply()
# Serve the app on INTERNAL_PORT, the port the ngrok tunnel above forwards to.
uvicorn.run(app, host='0.0.0.0', port=INTERNAL_PORT)