In [None]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util, models
from sklearn.feature_extraction import _stop_words as stop_words
from tqdm.notebook import tqdm
from rank_bm25 import BM25Okapi

import torch
import string
import json
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

import numpy as np
import pandas as pd

In [None]:
# Fine-tuned checkpoint directories (only loaded when use_base is False).
bi_enc_weights, cr_enc_weights = './models/bienc-exp7/', './models/crenc-exp7/'
# Sub-folder of ./data/ holding test_passage.json and test_corpus.json.
data_folder = 'generated5'
# Number of ranked candidates retrieved and evaluated per query.
top_k = 50
# True: pretrained 'distilroberta-base' + pooling and a public cross-encoder; False: the checkpoints above.
use_base = False

In [None]:
# Question table; the first column becomes the index, whose values are the
# ids used (as strings) to look up precomputed embeddings later on.
df = pd.read_excel(io='./data/20231004_data.xlsx', index_col=0)
df.head(n=2)

In [None]:
english_stopwords = set(stopwords.words('english'))

def bm25_tokenizer(text):
    """Tokenise *text* for BM25.

    The text is lower-cased and split on whitespace. Leading and trailing ASCII
    punctuation (``string.punctuation``) is trimmed from each piece. Pieces that
    end up empty or are English stopwords are dropped.

    Returns:
        List of the remaining tokens, in their original order.
    """
    trimmed = (piece.strip(string.punctuation) for piece in text.lower().split())
    return [tok for tok in trimmed if tok and tok not in english_stopwords]

In [None]:
if not use_base:
    # Fine-tuned checkpoints from the config cell.
    bi_encoder = SentenceTransformer(bi_enc_weights)
    cr_encoder = CrossEncoder(cr_enc_weights)
else:
    # Baseline: pretrained 'distilroberta-base' with a pooling layer on top
    # (library-default pooling settings), plus a public MS MARCO cross-encoder.
    word_embedding_model = models.Transformer('distilroberta-base', max_seq_length=350)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    cr_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-6')

In [None]:
# JSON is UTF-8 by specification; without an explicit encoding, open() uses the
# locale default, which can fail on non-ASCII text (e.g. cp1252 on Windows).
with open(f'./data/{data_folder}/test_passage.json', 'r', encoding='utf-8') as f:
    val_passage = json.load(f)

with open(f'./data/{data_folder}/test_corpus.json', 'r', encoding='utf-8') as f:
    val_corpus = json.load(f)

# Query id -> texts of its positive passages. Each id in rel[0] is looked up in
# val_corpus as a string key.
# NOTE(review): assumes rel[0] holds the positive passage ids of each record;
# confirm against the data generator.
val_query_answer = {
    idx: [val_corpus[str(p)] for p in rel[0]]
    for idx, rel in val_passage.items()
}

In [None]:
import torch
import json
import pandas as pd


def shorten(text):
    """Keep only the first 512 whitespace-separated words of *text*.

    Runs of whitespace collapse to single spaces in the result.
    """
    word_limit = 512
    words = text.split()
    return ' '.join(words[:word_limit])

val_text = list(val_corpus.values())


with open("embeddings_GPT.json", "r", encoding="utf-8") as jsonfile:
    embeddings_dict = json.load(jsonfile)

val_emb_tensors = []


# Question title -> row id of df. The ids, as strings, key embeddings_dict.
# If a title occurs more than once, dict() keeps the id of its last row.
title_to_id = dict(zip(df['Question Title'], df.index))

# Row i of val_emb must be the embedding of val_text[i]: retrieval returns row
# indices that are then looked up in val_text. Skipping a missing embedding
# would shift every later row and silently corrupt all metrics, so missing
# ids are collected and reported instead.
missing_ids = []
for text in val_text:
    text_id = title_to_id[text]
    embedding = embeddings_dict.get(str(text_id))
    if not embedding:
        missing_ids.append(text_id)
        continue
    val_emb_tensors.append(torch.tensor(embedding))

if missing_ids:
    raise ValueError(
        f'{len(missing_ids)} corpus texts have no embedding in embeddings_GPT.json '
        f'(first ids: {missing_ids[:5]}); val_emb rows would not line up with val_text.'
    )

val_emb = torch.stack(val_emb_tensors)



In [None]:
from tqdm import tqdm
# Entry i holds the BM25 tokens of val_text[i]: both follow val_corpus insertion order.
tokenized_corpus = [bm25_tokenizer(passage) for passage in tqdm(val_corpus.values())]

In [None]:
# BM25 index over the tokenised corpus; scores from it are indexed like tokenized_corpus / val_text.
bm25 = BM25Okapi(corpus=tokenized_corpus)

In [None]:
def evaluate_bm25(query, answer, top_k=50, exclude_query=True,
                  bm25_index=None, corpus_texts=None, tokenizer=None):
    """Compute AP@top_k and reciprocal rank of BM25 retrieval for one query.

    Args:
        query: Query text; it is tokenised and scored against the BM25 index.
        answer: Texts of the relevant passages. Matching is by exact text, and
            each distinct text is credited at most once.
        top_k: Number of ranked candidates to evaluate.
        exclude_query: If True (default), the first retrieved passage whose text
            equals ``query`` is skipped. This mirrors ``forward_pass_rerank``, so
            the query's own corpus entry does not take rank 1 and push every
            answer down by one rank.
        bm25_index: Object with ``get_scores(tokens)``; defaults to the global ``bm25``.
        corpus_texts: Sequence mapping corpus ids to texts; defaults to the global
            ``val_text``.
        tokenizer: Callable applied to ``query``; defaults to the global
            ``bm25_tokenizer``.

    Returns:
        ``(average_precision, reciprocal_rank)``. AP is normalised by
        ``len(answer)`` and is 0.0 for an empty ``answer``. Both values are 0.0
        when no answer is retrieved.

    Raises:
        ValueError: If the index scores fewer than ``top_k`` passages.
    """
    index = bm25 if bm25_index is None else bm25_index
    texts = val_text if corpus_texts is None else corpus_texts
    tokenize = bm25_tokenizer if tokenizer is None else tokenizer

    scores = np.asarray(index.get_scores(tokenize(query)))
    if len(scores) < top_k:
        raise ValueError("Not enough BM25 scores for top_k.")

    # Fetch one spare candidate when the query's own entry may be dropped below.
    n_fetch = min(top_k + 1, len(scores)) if exclude_query else top_k
    if n_fetch > 0:
        # argpartition selects the n_fetch best in O(n); only those are then sorted.
        candidates = np.argpartition(scores, -n_fetch)[-n_fetch:]
        ranked = sorted(candidates.tolist(), key=lambda i: scores[i], reverse=True)
    else:
        # argpartition(..., -0)[-0:] would return the whole corpus, not nothing.
        ranked = []

    if exclude_query:
        self_pos = next((pos for pos, cid in enumerate(ranked) if texts[cid] == query), None)
        if self_pos is not None:
            del ranked[self_pos]
    ranked = ranked[:top_k]

    remaining = set(answer)
    n_found = 0
    precision_sum = 0.0
    first_rank = None
    for rank, corpus_id in enumerate(ranked, start=1):
        candidate = texts[corpus_id]
        if candidate in remaining:
            remaining.remove(candidate)
            n_found += 1
            precision_sum += n_found / rank
            if first_rank is None:
                first_rank = rank

    bm25_map = precision_sum / len(answer) if len(answer) else 0.0
    bm25_mrr = 1 / first_rank if first_rank is not None else 0.0
    return bm25_map, bm25_mrr

def compute_cosine_similarity(query_embedding, corpus_embeddings, eps=1e-12):
    """Return the cosine-similarity matrix between query and corpus embeddings.

    Args:
        query_embedding: 1-D tensor (one query) or 2-D tensor with one query per row.
        corpus_embeddings: 2-D tensor with one corpus item per row; its column
            count must equal the query dimension.
        eps: Lower bound applied to vector norms. An all-zero vector then yields
            similarity 0 instead of NaN (which a plain division by the norm gives).

    Returns:
        Tensor of shape ``(n_queries, n_corpus)``; a 1-D query counts as one row.
    """
    if query_embedding.dim() == 1:
        query_embedding = query_embedding.unsqueeze(0)

    query_unit = torch.nn.functional.normalize(query_embedding, p=2, dim=1, eps=eps)
    corpus_unit = torch.nn.functional.normalize(corpus_embeddings, p=2, dim=1, eps=eps)
    return torch.mm(query_unit, corpus_unit.transpose(0, 1))



def forward_pass_rerank(query, precomputed_embedding=None, val_embeddings=None, top_k=50):
    """Retrieve the top_k corpus entries for *query* by cosine similarity, then cross-encode them.

    Args:
        query: Query text. It is used to drop the query's own corpus entry from
            the results and as the first element of every cross-encoder pair.
        precomputed_embedding: Optional query embedding tensor (1-D, or 2-D with
            a single row). When None, the query is encoded with the global ``bi_encoder``.
        val_embeddings: Corpus embedding matrix whose rows align with the global
            ``val_text``. If its second dimension differs from the query
            dimension, it is transposed.
        top_k: Number of hits to return (fewer if the corpus is smaller).

    Returns:
        List of up to ``top_k`` dicts in descending cosine-similarity order. Each has:
        ``'corpus_id'`` (int index into ``val_text``), ``'score'`` (float cosine
        similarity) and ``'cross_score'`` (the global ``cr_encoder``'s prediction
        for ``[query, text]``).

    Raises:
        ValueError: If ``val_embeddings`` is None.
    """
    if val_embeddings is None:
        raise ValueError("No embeddings provided for the validation set.")

    if precomputed_embedding is None:
        q_emb = bi_encoder.encode(query, convert_to_tensor=True)
    else:
        q_emb = precomputed_embedding

    # The encoder may return a tensor on another device (e.g. GPU) or dtype;
    # the matrix product needs both operands on the same device and dtype.
    q_emb = q_emb.to(device=val_embeddings.device, dtype=val_embeddings.dtype)

    if q_emb.dim() == 1:
        q_emb = q_emb.unsqueeze(0)

    if val_embeddings.shape[1] != q_emb.shape[1]:
        val_embeddings = val_embeddings.transpose(0, 1)

    similarities = compute_cosine_similarity(q_emb, val_embeddings)[0]

    # One spare candidate so that dropping the query's own entry still leaves
    # top_k; clamped so torch.topk does not fail on a corpus smaller than top_k + 1.
    k = min(top_k + 1, similarities.shape[0])
    best = torch.topk(similarities, k=k)
    hits = [
        {'corpus_id': corpus_id, 'score': score}
        for score, corpus_id in zip(best.values.tolist(), best.indices.tolist())
    ]

    # Remove the query's own corpus entry (its first exact-text match) before
    # cross-encoding, instead of scoring it and discarding the result.
    self_pos = next(
        (pos for pos, hit in enumerate(hits) if val_text[hit['corpus_id']] == query),
        None,
    )
    if self_pos is not None:
        del hits[self_pos]
    hits = hits[:top_k]

    if hits:
        cross_inputs = [[query, val_text[hit['corpus_id']]] for hit in hits]
        cross_scores = cr_encoder.predict(cross_inputs)
        for hit, cross_score in zip(hits, cross_scores):
            hit['cross_score'] = cross_score

    return hits




def evaluate_bi_encoder(hits, answer, top_k=50, corpus_texts=None):
    """Compute AP@top_k, reciprocal rank and per-rank hit flags for bi-encoder hits.

    Hits are ranked by their ``'score'`` (cosine similarity), highest first.

    Args:
        hits: Dicts with ``'corpus_id'`` and ``'score'`` keys.
        answer: Texts of the relevant passages. Matching is by exact text, and
            each distinct text is credited at most once.
        top_k: Number of ranks evaluated. Hits ranked below ``top_k`` are ignored,
            so ``hits`` may be longer than ``top_k`` (writing them into the
            length-``top_k`` flag lists would otherwise raise IndexError).
        corpus_texts: Sequence mapping corpus ids to texts; defaults to the global
            ``val_text``.

    Returns:
        ``(average_precision, reciprocal_rank, hit_list, hit_recall_list)``.
        AP is normalised by ``len(answer)`` (0.0 for an empty ``answer``).
        The reciprocal rank is 0.0 when no answer is found. Both lists have
        length ``top_k``, hold 1 at each 0-based rank with a newly matched
        answer and 0 elsewhere, and are separate list objects.
    """
    texts = val_text if corpus_texts is None else corpus_texts
    remaining = set(answer)
    ranked = sorted(hits, key=lambda hit: hit['score'], reverse=True)[:top_k]

    hit_flags = [0] * top_k
    n_found = 0
    precision_sum = 0.0
    first_rank = None
    for pos, hit in enumerate(ranked):
        candidate = texts[hit['corpus_id']]
        if candidate not in remaining:
            continue
        remaining.remove(candidate)
        hit_flags[pos] = 1
        n_found += 1
        precision_sum += n_found / (pos + 1)
        if first_rank is None:
            first_rank = pos + 1

    bi_enc_map = precision_sum / len(answer) if len(answer) else 0.0
    bi_enc_mrr = 1 / first_rank if first_rank is not None else 0.0
    # Separate list objects, so a caller mutating one does not affect the other.
    return bi_enc_map, bi_enc_mrr, hit_flags, list(hit_flags)

def evaluate_cr_encoder(hits, answer, top_k=50, mode='cross_score', corpus_texts=None):
    """Compute AP@top_k, reciprocal rank and per-rank hit flags for re-ranked hits.

    Args:
        hits: Dicts with a ``'corpus_id'`` and a numeric score stored under ``mode``.
        answer: Texts of the relevant passages. Matching is by exact text, and
            each distinct text is credited at most once.
        top_k: Number of ranks evaluated. Hits ranked below ``top_k`` are ignored,
            so ``hits`` may be longer than ``top_k`` (writing them into the
            length-``top_k`` flag lists would otherwise raise IndexError).
        mode: Hit key to rank by, highest first. ``'cross_score'`` by default;
            ``'score'`` reproduces the bi-encoder order.
        corpus_texts: Sequence mapping corpus ids to texts; defaults to the global
            ``val_text``.

    Returns:
        ``(average_precision, reciprocal_rank, hit_list, hit_recall_list)``.
        AP is normalised by ``len(answer)`` (0.0 for an empty ``answer``).
        The reciprocal rank is 0.0 when no answer is found. Both lists have
        length ``top_k``, hold 1 at each 0-based rank with a newly matched
        answer and 0 elsewhere, and are separate list objects.
    """
    texts = val_text if corpus_texts is None else corpus_texts
    remaining = set(answer)
    ranked = sorted(hits, key=lambda hit: hit[mode], reverse=True)[:top_k]

    hit_flags = [0] * top_k
    n_found = 0
    precision_sum = 0.0
    first_rank = None
    for pos, hit in enumerate(ranked):
        candidate = texts[hit['corpus_id']]
        if candidate not in remaining:
            continue
        remaining.remove(candidate)
        hit_flags[pos] = 1
        n_found += 1
        precision_sum += n_found / (pos + 1)
        if first_rank is None:
            first_rank = pos + 1

    cr_enc_map = precision_sum / len(answer) if len(answer) else 0.0
    cr_enc_mrr = 1 / first_rank if first_rank is not None else 0.0
    # Separate list objects, so a caller mutating one does not affect the other.
    return cr_enc_map, cr_enc_mrr, hit_flags, list(hit_flags)

def _precision_recall_at(hit_list, recall_list, n_answers, cutoffs=(1, 3, 5, 10)):
    """Return ``(precision, recall)`` lists with one entry per cutoff ``n``.

    precision@n = relevant flags in the first n ranks / n;
    recall@n    = relevant flags in the first n ranks / n_answers.
    """
    precision = [sum(hit_list[:n]) / n for n in cutoffs]
    recall = [sum(recall_list[:n]) / n_answers for n in cutoffs]
    return precision, recall


def _accumulate(totals, values):
    """Return the element-wise sum of two equal-length lists."""
    return [t + v for t, v in zip(totals, values)]


def _average_scores(scores, n_queries):
    """Divide every accumulated metric in *scores* by n_queries, in place.

    Scalar entries are divided directly; list entries (per-cutoff metrics) are
    divided element-wise.
    """
    for key, value in scores.items():
        if isinstance(value, list):
            scores[key] = [v / n_queries for v in value]
        else:
            scores[key] = value / n_queries


if not val_query_answer:
    raise ValueError('val_query_answer is empty; there are no queries to evaluate.')

bm25_scores = {'mrr': 0, 'map': 0}
bi_enc_scores = {'mrr': 0, 'map': 0, 'precision': [0] * 4, 'recall': [0] * 4}
cr_enc_scores = {'mrr': 0, 'map': 0, 'precision': [0] * 4, 'recall': [0] * 4}

for query_key, answers in tqdm(val_query_answer.items(), total=len(val_query_answer)):
    query = val_corpus[query_key]

    # Fetch the precomputed query embedding. A missing id would otherwise surface
    # as an opaque error from torch.tensor(None).
    query_id = title_to_id[query]
    raw_query_embedding = embeddings_dict.get(str(query_id))
    if not raw_query_embedding:
        raise KeyError(f'No precomputed embedding for query id {query_id}')
    precomputed_query_embedding = torch.tensor(raw_query_embedding)
    hits = forward_pass_rerank(query, precomputed_embedding=precomputed_query_embedding,
                               val_embeddings=val_emb, top_k=top_k)

    # Pass top_k explicitly: the evaluators size their per-rank lists by it, and
    # their default (50) breaks as soon as the configured top_k changes.
    bm25_map, bm25_mrr = evaluate_bm25(query, answers, top_k)
    b_map, b_mrr, b_hit, b_rec = evaluate_bi_encoder(hits, answers, top_k)
    c_map, c_mrr, c_hit, c_rec = evaluate_cr_encoder(hits, answers, top_k)

    for totals, hit_list, recall_list in ((bi_enc_scores, b_hit, b_rec),
                                          (cr_enc_scores, c_hit, c_rec)):
        precision, recall = _precision_recall_at(hit_list, recall_list, len(answers))
        totals['precision'] = _accumulate(totals['precision'], precision)
        totals['recall'] = _accumulate(totals['recall'], recall)

    bm25_scores['map'] += bm25_map
    bm25_scores['mrr'] += bm25_mrr

    bi_enc_scores['map'] += b_map
    bi_enc_scores['mrr'] += b_mrr

    cr_enc_scores['map'] += c_map
    cr_enc_scores['mrr'] += c_mrr

_average_scores(bm25_scores, len(val_query_answer))
_average_scores(bi_enc_scores, len(val_query_answer))
_average_scores(cr_enc_scores, len(val_query_answer))

In [None]:
# Show each retriever's averaged metrics as indented JSON under its label.
print('BM25:', json.dumps(bm25_scores, indent=2), sep='\n')
print('Bi-Encoder:', json.dumps(bi_enc_scores, indent=2), sep='\n')
print('Cross-Encoder:', json.dumps(cr_enc_scores, indent=2), sep='\n')