In [3]:
from sentence_transformers import SentenceTransformer, util, CrossEncoder
import torch
from datasets import load_dataset
import pandas as pd
pd.set_option("display.max_colwidth", 1400)
from tqdm import tqdm

In [4]:
#model_path = "output/train_bi-encoder-triplet-sentence-transformers-all-MiniLM-L12-v2-2024-06-10_11-07-27/2010/"
model_path = "output/train_bi-encoder-triplet-sentence-transformers-all-MiniLM-L12-v2-2024-06-17_06-30-11/268/"

embedder = SentenceTransformer(model_path)

In [5]:
baseline_embedder = SentenceTransformer("all-MiniLM-L12-v2")



In [6]:
baseline_cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

  return self.fget.__get__(instance, owner)()


In [7]:
tuned_cross_encoder = CrossEncoder('output/training_esci-crossencoder-exact_vs_nonexact2024-06-17 07:16:25.325538/')

In [8]:
esci = load_dataset("tasksource/esci")
esci_test = esci['test'].to_pandas().head(20000).tail(19000)
#esci_test[esci_test.esci_label == 'Irrelevant'][['query', 'query_id']].sample(10)

In [9]:
# We also compare the results to lexical search (keyword search). Here, we use 
# the BM25 algorithm which is implemented in the rank_bm25 package.

from rank_bm25 import BM25Okapi
from sklearn.feature_extraction import _stop_words
import string
from tqdm.autonotebook import tqdm
import numpy as np


# We lower case our text and remove stop-words from indexing
def bm25_tokenizer(text):
    tokenized_doc = []
    for token in text.lower().split():
        token = token.strip(string.punctuation)

        if len(token) > 0 and token not in _stop_words.ENGLISH_STOP_WORDS:
            tokenized_doc.append(token)
    return tokenized_doc


In [10]:
def create_item_text(row):
    title = row["product_title"]
    brand = row["product_brand"]
    color = row["product_color"]
    return f"Title: {title}. Brand: {brand}. Color: {color}"
    #return f"{title}"
    
def retrieve_items(embedder, query_id, cross_encoder, recall_size=25, use_cross_encoder=False):
    """
    query_id: query id for which items will be retrieved.
    topk : number of candidate items generated by semantic retrieval
    """
    test_query_id = query_id
    esci_test_sample = esci_test[(esci_test.query_id == test_query_id)].copy().head(recall_size)
    
    esci_test_sample['item_text'] = esci_test_sample.apply(lambda row: create_item_text(row), axis=1)
    esci_test_sample['item_embeddings'] = esci_test_sample['item_text'].apply(lambda x: embedder.encode(x, convert_to_tensor=True))
    
    query_embedding = embedder.encode(esci_test_sample['query'].values[0], convert_to_tensor=True)
    esci_test_sample['cos_score'] = esci_test_sample['item_embeddings'].apply(lambda item_embed:
                                    util.cos_sim(query_embedding, item_embed)[0].item()                  
                                    )

    if use_cross_encoder:
        esci_test_sample['cross_enc_score'] = esci_test_sample.apply(lambda row: 
                                                    cross_encoder.predict([(row["query"], row["item_text"])])[0], axis=1)
        esci_test_sample = esci_test_sample.sort_values(by=['cross_enc_score'], ascending=False)
    
    else:
        esci_test_sample = esci_test_sample.sort_values(by=['cos_score'], ascending=False)
    

    return esci_test_sample[['query','product_id','product_title', 'esci_label', 'cos_score']]

def retrieve_items_bm25(query_id, recall_size=25):
    test_query_id = query_id
    esci_test_sample = esci_test[(esci_test.query_id == test_query_id)].copy().head(recall_size)
    esci_test_sample['item_text'] = esci_test_sample.apply(lambda row: create_item_text(row), axis=1)
    esci_test_sample['item_text_tokenized'] = esci_test_sample['item_text'].apply(lambda text: bm25_tokenizer(text))

    bm25 = BM25Okapi(esci_test_sample['item_text_tokenized'].tolist())

    query = esci_test_sample['query'].values[0]
    
    bm25_scores = bm25.get_scores(bm25_tokenizer(query))
    esci_test_sample['bm25_scores'] = bm25_scores
    esci_test_sample = esci_test_sample.sort_values(by=['bm25_scores'], ascending=False)
    
    return esci_test_sample[['query','product_id','product_title', 'esci_label', 'bm25_scores']]




In [15]:
def precision_at(k, ranked_list, eval_criteria="Exact"):
    return sum([rank == eval_criteria for rank in ranked_list[:k]])

def reciprocal_rank(ranked_list, eval_criteria="Exact"):
    for i, rank in enumerate(ranked_list):
        if rank == eval_criteria:
            return 1/(i+1)
    return 0

def eval_search_sample(embedder, eval_df, cross_encoder, use_cross_encoder, k=10, recall_size=25, eval_criteria="Exact"):
    mean_precision_at_k = []
    mean_reciprocal_rank = []
    
    for query_id in tqdm(eval_df.query_id):
        retrieved_sample = retrieve_items(embedder, query_id, cross_encoder, use_cross_encoder=use_cross_encoder, recall_size=recall_size)
        
        precision = precision_at(k, retrieved_sample.esci_label.tolist(), eval_criteria)
        mean_precision_at_k.append(precision)

        reci_rank = reciprocal_rank(retrieved_sample.esci_label.tolist(), eval_criteria)
        mean_reciprocal_rank.append(reci_rank)

    mean_precision_at_k_val = sum(mean_precision_at_k)/len(mean_precision_at_k)
    mean_reciprocal_rank_val = sum(mean_reciprocal_rank)/len(mean_reciprocal_rank)
    
    return {f"mean_precision_at_{k}": mean_precision_at_k_val, "mean_reciprocal_rank": mean_reciprocal_rank_val}

def eval_bm25_search(eval_df, k=10, recall_size=25, eval_criteria="Exact"):
    mean_precision_at_k = []
    mean_reciprocal_rank = []

    for query_id in tqdm(eval_df.query_id):
        retrieved_sample = retrieve_items_bm25(query_id, recall_size=recall_size)
        precision = precision_at(k, retrieved_sample.esci_label.tolist(), eval_criteria)
        mean_precision_at_k.append(precision)

        reci_rank = reciprocal_rank(retrieved_sample.esci_label.tolist(), eval_criteria)
        mean_reciprocal_rank.append(reci_rank)

    mean_precision_at_k_val = sum(mean_precision_at_k)/len(mean_precision_at_k)
    mean_reciprocal_rank_val = sum(mean_reciprocal_rank)/len(mean_reciprocal_rank)
    
    return {f"mean_precision_at_{k}": mean_precision_at_k_val, "mean_reciprocal_rank": mean_reciprocal_rank_val}


In [12]:
eval_set = esci_test[(esci_test.esci_label == 'Irrelevant') & (esci_test.product_locale == 'us')][['query', 'query_id']]
eval_set = eval_set.drop_duplicates(subset=["query_id"]).sample(1000, random_state=19)
eval_set.shape

(1000, 2)

In [13]:
eval_bm25_search(eval_set, k=5, recall_size=15, eval_criteria="Exact")

  0%|          | 0/1000 [00:00<?, ?it/s]

{'mean_precision_at_5': 0.705, 'mean_reciprocal_rank': 0.25967825785325793}

In [16]:
eval_search_sample(baseline_embedder, eval_set, baseline_cross_encoder, use_cross_encoder=False, recall_size=15, k=5, eval_criteria="Exact")

  0%|          | 0/1000 [00:00<?, ?it/s]

{'mean_precision_at_5': 0.716, 'mean_reciprocal_rank': 0.2666905233655234}

In [17]:
eval_search_sample(embedder, eval_set, baseline_cross_encoder, use_cross_encoder=False, recall_size=15, k=5, eval_criteria="Exact")

  0%|          | 0/1000 [00:00<?, ?it/s]

{'mean_precision_at_5': 0.734, 'mean_reciprocal_rank': 0.27109484126984124}

In [18]:
eval_search_sample(embedder, eval_set, baseline_cross_encoder, use_cross_encoder=True, recall_size=15, k=5, eval_criteria="Exact")

  0%|          | 0/1000 [00:00<?, ?it/s]

{'mean_precision_at_5': 0.737, 'mean_reciprocal_rank': 0.27779047619047614}

In [19]:
eval_search_sample(embedder, eval_set, tuned_cross_encoder, use_cross_encoder=True, recall_size=15, k=5, eval_criteria="Exact")

  0%|          | 0/1000 [00:00<?, ?it/s]

{'mean_precision_at_5': 0.796, 'mean_reciprocal_rank': 0.28968477355977346}

In [26]:
# def precision_at(k, ranked_list, eval_criteria="Exact"):
#     return sum([rank == eval_criteria for rank in ranked_list[:k]])

# def reciprocal_rank(ranked_list, eval_criteria="Exact"):
#     for i, rank in enumerate(ranked_list):
#         if rank == eval_criteria:
#             return 1/(i+1)
#     return 0

# def eval_search_sample(embedder, eval_df, cross_encoder, use_cross_encoder, k=10, recall_size=25, eval_criteria="Exact"):
#     mean_precision_at_k = []
#     mean_reciprocal_rank = []
    
#     for query_id in tqdm(eval_df.query_id):
#         retrieved_sample = retrieve_items(embedder, query_id, cross_encoder, use_cross_encoder=use_cross_encoder, recall_size=recall_size)
        
#         precision = precision_at(k, retrieved_sample.esci_label.tolist(), eval_criteria)
#         mean_precision_at_k.append(recall)

#         reci_rank = reciprocal_rank(retrieved_sample.esci_label.tolist(), eval_criteria)
#         mean_reciprocal_rank.append(reci_rank)

#     mean_precision_at_k_val = sum(mean_precision_at_k)/len(mean_precision_at_k)
#     mean_reciprocal_rank_val = sum(mean_reciprocal_rank)/len(mean_reciprocal_rank)
    
#     return {f"mean_precision_at_{k}": mean_precision_at_k_val, "mean_reciprocal_rank": mean_reciprocal_rank_val}

# def eval_bm25_search(eval_df, k=10, recall_size=25, eval_criteria="Exact"):
#     mean_precision_at_k = []
#     mean_reciprocal_rank = []

#     for query_id in tqdm(eval_df.query_id):
#         retrieved_sample = retrieve_items_bm25(query_id, recall_size=recall_size)
#         recall = precision_at(k, retrieved_sample.esci_label.tolist(), eval_criteria)
#         mean_precision_at_k.append(recall)

#         reci_rank = reciprocal_rank(retrieved_sample.esci_label.tolist(), eval_criteria)
#         mean_reciprocal_rank.append(reci_rank)

#     mean_precision_at_k_val = sum(mean_precision_at_k)/len(mean_precision_at_k)
#     mean_reciprocal_rank_val = sum(mean_reciprocal_rank)/len(mean_reciprocal_rank)
    
#     return {f"mean_precision_at_{k}": mean_precision_at_k_val, "mean_reciprocal_rank": mean_reciprocal_rank_val}
        