# Search embedding

In [1]:
# !pip install -U sentence-transformers

## Load dataset

In [1]:
# Here we store constant varables with path to our datasets
# SEARCH_DATA_INDEX_PATH = "../../datasets" #"../../datasets/metadata_for_search"
# MODEL_PATH = "../../models/"

# SEARCH_DATA_PATH = "../../datasets"
# Here we store constant varables with path to our datasets
SEARCH_DATA_INDEX_PATH = "../../datasets/processed_datasets"
MODEL_PATH = "../../models/"

SEARCH_DATA_PATH = "../../datasets"

#### load search data

In [2]:
import pandas as pd
data_vote = pd.read_json(SEARCH_DATA_PATH+"/search_dataset_vote.json")
data_active = pd.read_json(SEARCH_DATA_PATH+"/search_dataset_active.json")
data_hot = pd.read_json(SEARCH_DATA_PATH+"/search_dataset_hot.json")

search_data_index_hot = pd.read_csv(SEARCH_DATA_INDEX_PATH+"/metadata-for-search-info-hottest-rating.csv")
search_data_index_vote = pd.read_csv(SEARCH_DATA_INDEX_PATH+"/metadata-for-search-info.csv")
search_data_index_active = pd.read_csv(SEARCH_DATA_INDEX_PATH+"/metadata-for-search-info-active-rating.csv")

In [3]:
search_data_index_active

Unnamed: 0,query name,rating_active,path to metadata
0,Top 1000 Movies Dataset,1,ritiksharma07/imdb-top-1000-movies-dataset
1,Top 1000 Movies Dataset,2,emanchauhdary/imdb-top-1000-dataset
2,Top 1000 Movies Dataset,3,sharathmudigoudr/imdb-movie-dataset-from-year-...
3,Top 1000 Movies Dataset,4,kunalduttads/tmdb-top-10000-10k-movies-dataset
4,Top 1000 Movies Dataset,5,dogacelik/google-data-analytics-capstone-proje...
...,...,...,...
452,classification,16,nourmekkijj/stress-and-anxiety-posts-on-reddit
453,classification,17,sabunbalt/skin-disease-classification
454,classification,18,carloscortes18/players-chess-pieces
455,classification,19,haicouheba/car-classification


In [4]:
data_hot['id'], data_hot['text']

(0        0
 1        1
 2        2
 3        3
 4        4
       ... 
 684    685
 685    686
 686    687
 687    688
 688    689
 Name: id, Length: 689, dtype: int64,
 0       Top 10000 Popular Movies Dataset Top 10000 Po...
 1       IMDB Top 1000 Movies Dataset "Release Year, D...
 2       Top 1000 IMDb Movies Dataset Discover the Gre...
 3       Top Rated 10000 Movies Dataset (IMDB) In this...
 4       Top 1000 Highest Grossing Movies Top 1000 Hig...
                              ...                        
 684     Scene Classification Contains ~25K images fro...
 685     Plants Classification 30 Types of Plants Imag...
 686     Flower Classification 14 Types of Flower Imag...
 687     BIRDS 525  SPECIES- IMAGE CLASSIFICATION 525 ...
 688     Brain Tumor Classification (MRI) Classify MRI...
 Name: text, Length: 689, dtype: object)

In [5]:
queries = search_data_index_vote['query name'].unique()

## Inference batching utils

In [7]:
def iterate_batch(dataset,batch_size=128):
    for i in range(0,len(dataset),batch_size):
        yield dataset[i:i+batch_size]

In [8]:
iterate_batch([1,2,3,4,5,6,7,8,9,10],3)

<generator object iterate_batch at 0x7f96b466c510>

## Models

### Standart models

In [9]:
import gc

#### top 1 Salesforce/SFR-Embedding-2_R

In [10]:
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel

tokenizer_SFR = AutoTokenizer.from_pretrained('Salesforce/SFR-Embedding-2_R',cache_dir=MODEL_PATH)
model_SFR = AutoModel.from_pretrained('Salesforce/SFR-Embedding-2_R',cache_dir=MODEL_PATH)

def sfr_model_inference(text=None,dataset=None):
    global tokenizer_SFR
    global model_SFR
    def last_token_pool(last_hidden_states: Tensor,
                     attention_mask: Tensor) -> Tensor:
        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
        if left_padding:
            return last_hidden_states[:, -1]
        else:
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_states.shape[0]
            return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
    
    def get_detailed_instruct(task_description: str, query: str) -> str:
        return f'Instruct: {task_description}\nQuery: {query}'

    # Query process
    # Each query must come with a one-sentence instruction that describes the task
    # task = 'Given a web search query, retrieve relevant passages that answer the query'
    # queries = [
    #     get_detailed_instruct(task, '<Query1>'),
    #     get_detailed_instruct(task, '<Query2>')
    # ]

    with torch.no_grad():
        # load model and tokenizer
        tokenizer = tokenizer_SFR
        model = model_SFR
        model = model.to(device)
        embeddings = 0
        model.eval()

        if text is None:
            text = list(dataset["text_str"])
        for batch in iterate_batch(text,10):
            # get the embeddings
            embedding=0
            try:
                model = model.to(device)
                max_length = 4096
                batch_dict = tokenizer(batch, max_length=max_length, padding=True, truncation=True, return_tensors="pt").to(device)
                outputs = model(**batch_dict)
                embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            
                # normalize embeddings
                embedding = F.normalize(embedding, p=2, dim=1)
                embedding = embedding.to('cpu')
    
                if type(embeddings) is int:
                    embeddings = embedding
                else:
                    embeddings = torch.cat((embeddings,embedding),dim=0)
            except torch.cuda.OutOfMemoryError:
                del model
                del batch_dict
                del embedding
                torch.cuda.empty_cache()
                raise torch.cuda.OutOfMemoryError
            del embedding
            del batch_dict
            torch.cuda.empty_cache()
        del model
        del tokenizer
        torch.cuda.empty_cache()
        return embeddings

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(
Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.13s/it]


### Load vector database

In [11]:
all_results_hot = []
all_results_vote = []
all_results_active = []

In [12]:
metrics = {}

In [13]:
import torch
from typing import Optional, cast

import numpy as np
import numpy.typing as npt
from chromadb.api.types import EmbeddingFunction, Documents, Embeddings


class TransformerEmbeddingFunction(EmbeddingFunction[Documents]):
    def __init__(
            self, model_name
    ):
        self.model_name=model_name
    def __call__(self, input: Documents) -> Embeddings:

        return run_gte_model(text=input,model_path=model_name).tolist()

    def preprocess_query(self,text_list):
        return text_list

class TransformerEmbeddingFunction_SFR(EmbeddingFunction[Documents]):
    def __init__(
            self
    ):
        return
    def __call__(self, input: Documents) -> Embeddings:

        return sfr_model_inference(text=input).tolist()
    
    def preprocess_query(self,text_list):
        text_list = [ f"Instruct: Given a web search query,\
                      retrieve relevant passages that answer the query\nQuery: {text}"
            for text in text_list]
        return text_list

In [14]:
# model_name="thenlper/gte-large"
# model_embed = TransformerEmbeddingFunction(model_name)
model_name="SFR"
model_embed = TransformerEmbeddingFunction_SFR()
device = "cuda"

In [15]:
model_name

'SFR'

In [16]:
load_time = 0
search_time = 0

In [17]:
# client.delete_collection("seach_db_vote")
# client.delete_collection("seach_db_active")
# client.delete_collection("seach_db_hot")

In [18]:
import chromadb
client = chromadb.Client()

In [None]:
import chromadb
import time

metrics[model_name] = {"active":{},"vote":{},"hot":{}}

client = chromadb.Client()

collection_vote = client.create_collection("seach_db_vote",get_or_create=False,embedding_function=model_embed)
collection_active = client.create_collection("seach_db_active",get_or_create=False,embedding_function=model_embed)
collection_hot = client.create_collection("seach_db_hot",get_or_create=False,embedding_function=model_embed)
# collection = client.create_collection("test_seach_fb",embedding_function,get_or_create=True)

# load
load_time = time.time()
collection_vote.add(
    documents=list(data_vote['text']),
    ids=[str(i) for i in data_vote['id']]
)
load_time = time.time() - load_time
metrics[model_name]['vote']['load_time'] = load_time

load_time = time.time()
collection_active.add(
    documents=list(data_active['text']),
    ids=[str(i) for i in data_active['id']]
)
load_time = time.time() - load_time
metrics[model_name]['active']['load_time'] = load_time

load_time = time.time()
collection_hot.add(
    documents=list(data_hot['text']),
    ids=[str(i) for i in data_hot['id']]
)
load_time = time.time() - load_time
metrics[model_name]['hot']['load_time'] = load_time

In [19]:
search_time = time.time()
queries = list(search_data_index_vote['query name'].unique())
queries = model_embed.preprocess_query(queries)
results_vote = collection_vote.query(
    query_texts=queries, # Chroma will embed this for you
    n_results=40 # how many results to return
)
search_time = time.time() - search_time
metrics[model_name]['vote']['search_time'] = search_time

search_time = time.time()
queries = list(search_data_index_hot['query name'].unique())
queries = model_embed.preprocess_query(queries)
results_hot = collection_hot.query(
    query_texts=queries, # Chroma will embed this for you
    n_results=40 # how many results to return
)
search_time = time.time() - search_time
metrics[model_name]['hot']['search_time'] = search_time

search_time = time.time()
queries = list(search_data_index_active['query name'].unique())
queries = model_embed.preprocess_query(queries)
results_active = collection_active.query(
    query_texts=queries, # Chroma will embed this for you
    n_results=40 # how many results to return
)
search_time = time.time() - search_time
metrics[model_name]['active']['search_time'] = search_time



### Inference

In [None]:
collection_hot.query(query_texts=["anime","recomendation system"],n_results=5)

In [None]:
client.delete_collection("seach_db_vote")
client.delete_collection("seach_db_active")
client.delete_collection("seach_db_hot")

## Evaluation

Evaluation
https://www.cs.cmu.edu/~ytsvetko/papers/qvec.pdf

In [20]:
import sys
sys.path.insert(0, '../../scripts/')

In [21]:
from evaluation import evaluate_algs

In [42]:
from ranx import Qrels, Run
"""
This code provides functions for creating and evaluating
document ranking systems using the ranx library
"""

def create_qrels(index_data,column_name):
    """
    This function creates a Qrels object from the index_data DataFrame.
    It iterates over the unique queries in the index_data DataFrame.
    For each query, it creates a dictionary mapping document IDs to their corresponding ratings.
    The function returns a Qrels object containing the query-document relevance scores.
    """
    qrels_dict = {}
    # index_data = search_data_index_hot
    queries = index_data['query name'].unique()
    
    for i in range(len(queries)):
        current_queries = queries[i]
        query_dict = {}
        current_index = index_data[index_data['query name']==current_queries]
        for doc_idx,rating in zip(current_index.index,current_index[column_name]):
            query_dict[str(doc_idx)] = rating
        qrels_dict[str(i)] = query_dict
    return Qrels(qrels_dict)

def create_run(results,index_data):
    """
    This function creates a Run object from the results dictionary and index_data DataFrame.
    It iterates over the unique queries in the index_data DataFrame.
    For each query, it creates a dictionary mapping document IDs to their corresponding scores (calculated as 1 / (distance + 1e-10)).
    The function returns a Run object containing the query-document scores.
    """
    runs = {}
    queries = index_data['query name'].unique()
    results
    for i in range(len(queries)):
        
        res_dict = {}
        for doc_idx,rating in zip(results['ids'][i],results['distances'][i]):
            res_dict[str(doc_idx)] = 1/(rating+1e-10)
        runs[str(i)] = res_dict
    return Run(runs)


from ranx import compare, evaluate

def evaluate_algs(results,index_data,column_name):
    """
    This function evaluates the performance of an information retrieval system.
    It calls the create_qrels and create_run functions to create the Qrels and Run objects.
    It defines a list of evaluation metrics to be used, including recall, precision, hits, MAP, MRR, and NDCG at ranks 10 and 40.
    The function returns the evaluation results using the evaluate function from ranx.
    """
    qrels = create_qrels(index_data,column_name)
    run = create_run(results,index_data)
    
    metrics = ["recall@10","precision@10","hits@10","map@10", "mrr@10", "ndcg@10"]
    metrics += ["recall@40","precision@40","hits@40","map@40", "mrr@40", "ndcg@40"]
    return evaluate(qrels, run, metrics,make_comparable=True)


In [49]:
def compare_algs(results_list,index_data,column_name):
    metrics = ["recall@10","precision@10","hits@10","map@10", "mrr@10", "ndcg@10"]
    metrics += ["recall@40","precision@40","hits@40","map@40", "mrr@40", "ndcg@40"]
    
    qrels = create_qrels(index_data,column_name)
    runs = [create_run(result,index_data) for result in results_list]
    
    return compare(
    qrels=qrels,
    runs=runs,
    metrics=metrics,
    max_p=0.01,  # P-value threshold
    make_comparable=True
)

In [43]:
metrics[model_name]['vote']['metrics'] = evaluate_algs(results_vote,search_data_index_vote,column_name='rating')
metrics[model_name]['active']['metrics'] = evaluate_algs(results_active,search_data_index_active,column_name='rating_active')
metrics[model_name]['hot']['metrics'] = evaluate_algs(results_hot,search_data_index_hot,column_name='rating_hottest')

In [55]:
metrics[model_name]['vote']['metrics']

{'recall@10': 0.24509114992859457,
 'precision@10': 0.5978260869565217,
 'hits@10': 5.978260869565218,
 'map@10': 0.19335366085282796,
 'mrr@10': 0.8148550724637681,
 'ndcg@10': 0.38780801956713495,
 'recall@40': 0.5591251293854966,
 'precision@40': 0.46304347826086967,
 'hits@40': 18.52173913043478,
 'map@40': 0.4029660422193646,
 'mrr@40': 0.8159992372234935,
 'ndcg@40': 0.48027479681135155}

In [54]:
metrics[model_name]['active']['metrics']

{'recall@10': 0.5775219298245614,
 'precision@10': 0.5605263157894737,
 'hits@10': 5.605263157894737,
 'map@10': 0.46619869987468665,
 'mrr@10': 0.8173976608187135,
 'ndcg@10': 0.5564827486649929,
 'recall@40': 0.8576597744360903,
 'precision@40': 0.25789473684210523,
 'hits@40': 10.31578947368421,
 'map@40': 0.6390193875472306,
 'mrr@40': 0.8173976608187135,
 'ndcg@40': 0.6518684108367149}

In [53]:
metrics[model_name]['hot']['metrics']

{'recall@10': 0.4800489914126278,
 'precision@10': 0.6568181818181817,
 'hits@10': 6.568181818181818,
 'map@10': 0.4212802985354392,
 'mrr@10': 0.9128787878787878,
 'ndcg@10': 0.4656944557747741,
 'recall@40': 0.879333103764922,
 'precision@40': 0.3414772727272727,
 'hits@40': 13.659090909090908,
 'map@40': 0.6726354645065095,
 'mrr@40': 0.9128787878787878,
 'ndcg@40': 0.6368186735738115}