In [None]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util, models
from sklearn.feature_extraction import _stop_words as stop_words
from tqdm.notebook import tqdm
from rank_bm25 import BM25Okapi

import torch
import string
import json
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

import numpy as np
import pandas as pd

In [None]:
# Fine-tuned checkpoints; only loaded when use_base is False.
bi_enc_weights, cr_enc_weights = './models/bienc-exp7/', './models/crenc-exp7/'
data_folder = 'generated5'  # sub-folder of ./data holding dataset.json and the test_*.json files
top_k = 50                  # number of candidates retrieved and evaluated per query
use_base = True             # True: evaluate the off-the-shelf base models instead of the fine-tuned weights

In [None]:
# Full dataset (question titles, answers, ids) used for the manual spot checks below.
dataset_path = f'./data/{data_folder}/dataset.json'
df = pd.read_json(dataset_path)
df.head(2)

In [None]:
# NLTK's English stop-word list, held as a set for fast membership tests in bm25_tokenizer.
english_stopwords = {word for word in stopwords.words('english')}

def bm25_tokenizer(text, stop_words=None):
    """Tokenize ``text`` for BM25.

    Lowercases, splits on whitespace, strips leading/trailing punctuation from
    each token, and drops empty tokens and stop words.

    Args:
        text: passage or query string to tokenize.
        stop_words: optional container of tokens to discard. When ``None``
            (the default) the module-level ``english_stopwords`` set is used,
            looked up at call time, so existing callers behave as before.

    Returns:
        list[str]: the surviving tokens in their original order.
    """
    if stop_words is None:
        stop_words = english_stopwords
    stripped = (token.strip(string.punctuation) for token in text.lower().split())
    return [token for token in stripped if token and token not in stop_words]

In [None]:
if not use_base:
    # Checkpoints produced by the fine-tuning experiments (paths from the config cell).
    bi_encoder = SentenceTransformer(bi_enc_weights)
    cr_encoder = CrossEncoder(cr_enc_weights)
else:
    # Baseline bi-encoder: raw distilroberta-base followed by a Pooling layer with
    # the library's default settings; baseline re-ranker: a public MS MARCO cross-encoder.
    word_embedding_model = models.Transformer('distilroberta-base', max_seq_length=350)
    embedding_dim = word_embedding_model.get_word_embedding_dimension()
    pooling_model = models.Pooling(embedding_dim)
    bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    cr_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-6')

In [None]:
# Load the held-out split. Explicit UTF-8 so decoding does not depend on the
# platform's default locale encoding (e.g. cp1252 on Windows).
with open(f'./data/{data_folder}/test_passage.json', 'r', encoding='utf-8') as f:
    val_passage = json.load(f)

with open(f'./data/{data_folder}/test_corpus.json', 'r', encoding='utf-8') as f:
    val_corpus = json.load(f)

# Map each query id to the texts of its relevant passages.
# NOTE(review): assumes val_passage[qid][0] is the list of positive passage ids
# (as the original indexing implies) — confirm against the data generator.
val_query_answer = {
    idx: [val_corpus[str(p)] for p in rel[0]]
    for idx, rel in val_passage.items()
}

In [None]:
# Embed every validation passage once. val_text keeps val_corpus order, so a
# search hit's corpus_id indexes straight into val_text.
val_text = [*val_corpus.values()]
val_emb = bi_encoder.encode(
    val_text,
    convert_to_tensor=True,
    show_progress_bar=True,
)

In [None]:
from tqdm import tqdm

# Tokenize every passage for BM25 in val_corpus order, so that position i in
# the BM25 score vector corresponds to val_text[i].
tokenized_corpus = [bm25_tokenizer(passage) for passage in tqdm(val_corpus.values())]

In [None]:
# Okapi BM25 index over the tokenized validation passages (default k1/b/epsilon).
bm25 = BM25Okapi(corpus=tokenized_corpus)

In [None]:
# functions to evaluate a single sample

def evaluate_bm25(query, answer, top_k=50):
    """Score BM25 retrieval for one query by MAP and MRR over its top-``top_k`` hits.

    The query's own passage is part of the BM25 index (queries are drawn from
    ``val_corpus``). So, like ``forward_pass`` does for the bi-encoder, one
    extra candidate is retrieved and the hit whose text equals the query is
    dropped before scoring. Otherwise BM25 nearly always spends rank 1 on the
    query itself, which unfairly lowers its scores relative to the neural models.

    Args:
        query: query text (also present in ``val_text``).
        answer: list of relevant passage texts.
        top_k: number of ranked candidates to evaluate.

    Returns:
        (map, mrr): the summed precision at each relevant rank divided by
        ``len(answer)``, and the reciprocal rank of the first relevant hit
        (0.0 if none is retrieved). Both are 0.0 when ``answer`` is empty.
    """
    scores = bm25.get_scores(bm25_tokenizer(query))

    # One extra slot makes room for dropping the query itself; clamp so
    # argpartition never asks for more items than the corpus holds.
    n_candidates = min(top_k + 1, len(scores))
    if n_candidates > 0:
        top_ids = np.argpartition(scores, -n_candidates)[-n_candidates:]
        top_ids = sorted(top_ids, key=lambda i: scores[i], reverse=True)
    else:
        top_ids = []
    ranked = [val_text[i] for i in top_ids]
    if query in ranked:
        ranked.remove(query)  # drop the (highest-ranked) copy of the query's own passage
    ranked = ranked[:top_k]

    if not answer:
        return 0.0, 0.0

    # Check membership against the shrinking copy so each relevant passage is
    # credited at most once (duplicate texts previously crashed on remove()).
    remaining = list(answer)
    n_found = 0
    precision_sum = 0.0
    first_rank = None
    for rank, candidate in enumerate(ranked, start=1):
        if candidate not in remaining:
            continue
        remaining.remove(candidate)
        n_found += 1
        precision_sum += n_found / rank
        if first_rank is None:
            first_rank = rank

    bm25_map = precision_sum / len(answer)
    bm25_mrr = 0.0 if first_rank is None else 1 / first_rank
    return bm25_map, bm25_mrr
        
def forward_pass(query, top_k=50):
    """Retrieve candidates for ``query`` with the bi-encoder and attach cross-encoder scores.

    ``top_k + 1`` passages are retrieved so that, if the query's own passage
    (exact text match in ``val_text``) is among them, it can be dropped and
    ``top_k`` genuine candidates remain. Only the kept hits are sent to the
    cross-encoder. The query-vs-itself pair and any surplus hit are never
    scored.

    Args:
        query: query text.
        top_k: maximum number of hits to return.

    Returns:
        list[dict]: up to ``top_k`` hits in the order returned by
        ``util.semantic_search``. Each hit carries its ``'corpus_id'`` and
        ``'score'`` plus an added ``'cross_score'``.
    """
    q_emb = bi_encoder.encode(query, convert_to_tensor=True)
    hits = util.semantic_search([q_emb], val_emb, top_k=top_k + 1)[0]

    # Drop the query's own passage (single pass instead of hits.index per hit).
    for pos, hit in enumerate(hits):
        if val_text[hit['corpus_id']] == query:
            del hits[pos]
            break
    hits = hits[:top_k]

    if hits:
        cross_inputs = [[query, val_text[hit['corpus_id']]] for hit in hits]
        cross_scores = cr_encoder.predict(cross_inputs)
        for hit, cross_score in zip(hits, cross_scores):
            hit['cross_score'] = cross_score

    return hits

def evaluate_bi_encoder(hits, answer, top_k=50):
    """Score a query's hits ranked by bi-encoder similarity (``hit['score']``).

    Args:
        hits: list of dicts with at least ``'corpus_id'`` (index into
            ``val_text``) and ``'score'``, e.g. the output of ``forward_pass``.
        answer: list of relevant passage texts for the query.
        top_k: minimum length of the returned per-rank lists.

    Returns:
        (map, mrr, hit_list, recall_list):
          * map: summed precision at each relevant rank, divided by ``len(answer)``.
          * mrr: reciprocal rank of the first relevant hit (0.0 if none).
          * hit_list / recall_list: 0/1 flag per rank (1 = a relevant passage
            found at that rank). Both carry the same flags and are kept
            separate for backward compatibility. Length is
            ``max(top_k, len(hits))``, so more hits than ``top_k`` no longer
            raises IndexError.
        map and mrr are 0.0 when ``answer`` is empty.
    """
    ranked = sorted(hits, key=lambda x: x['score'], reverse=True)
    n_slots = max(top_k, len(ranked))
    hit_list = [0] * n_slots
    recall_list = [0] * n_slots

    if not answer:
        return 0.0, 0.0, hit_list, recall_list

    remaining = list(answer)  # credit each relevant passage at most once
    n_found = 0
    precision_sum = 0.0
    first_rank = None
    for pos, hit in enumerate(ranked):
        candidate = val_text[hit['corpus_id']]
        if candidate not in remaining:
            continue
        remaining.remove(candidate)
        hit_list[pos] = 1
        recall_list[pos] = 1
        n_found += 1
        precision_sum += n_found / (pos + 1)
        if first_rank is None:
            first_rank = pos + 1

    bi_enc_map = precision_sum / len(answer)
    bi_enc_mrr = 0.0 if first_rank is None else 1 / first_rank
    return bi_enc_map, bi_enc_mrr, hit_list, recall_list

def evalaute_cr_encoder(hits, answer, top_k=50):
    """Score a query's hits re-ranked by cross-encoder score (``hit['cross_score']``).

    Args:
        hits: list of dicts with at least ``'corpus_id'`` (index into
            ``val_text``) and ``'cross_score'``, e.g. the output of ``forward_pass``.
        answer: list of relevant passage texts for the query.
        top_k: minimum length of the returned per-rank lists.

    Returns:
        (map, mrr, hit_list, recall_list):
          * map: summed precision at each relevant rank, divided by ``len(answer)``.
          * mrr: reciprocal rank of the first relevant hit (0.0 if none).
          * hit_list / recall_list: 0/1 flag per rank (1 = a relevant passage
            found at that rank). Both carry the same flags and are kept
            separate for backward compatibility. Length is
            ``max(top_k, len(hits))``, so more hits than ``top_k`` no longer
            raises IndexError.
        map and mrr are 0.0 when ``answer`` is empty.
    """
    ranked = sorted(hits, key=lambda x: x['cross_score'], reverse=True)
    n_slots = max(top_k, len(ranked))
    hit_list = [0] * n_slots
    recall_list = [0] * n_slots

    if not answer:
        return 0.0, 0.0, hit_list, recall_list

    remaining = list(answer)  # credit each relevant passage at most once
    n_found = 0
    precision_sum = 0.0
    first_rank = None
    for pos, hit in enumerate(ranked):
        candidate = val_text[hit['corpus_id']]
        if candidate not in remaining:
            continue
        remaining.remove(candidate)
        hit_list[pos] = 1
        recall_list[pos] = 1
        n_found += 1
        precision_sum += n_found / (pos + 1)
        if first_rank is None:
            first_rank = pos + 1

    cr_enc_map = precision_sum / len(answer)
    cr_enc_mrr = 0.0 if first_rank is None else 1 / first_rank
    return cr_enc_map, cr_enc_mrr, hit_list, recall_list


# Correctly spelled alias; the misspelled name is kept for existing callers.
evaluate_cr_encoder = evalaute_cr_encoder

In [None]:

# Show the dataset's stored answer for the sample question examined in the next cell.
# `q` is defined here so the cell works on a fresh kernel instead of relying on
# a value left over from running a later cell first.
q = val_corpus['2281']
print(df.loc[df['Question Title'] == q, 'answer'])

In [None]:
# test on 1 sample: retrieve for one known query, report its metrics, then
# recompute cosine similarity and cross-encoder scores from scratch to check
# them against the scores stored on the hits.

q = val_corpus['2281']
a = val_query_answer['2281']
hits = forward_pass(q)
b_map, b_mrr, _, _ = evaluate_bi_encoder(hits, a)
c_map, c_mrr, _, _ = evalaute_cr_encoder(hits, a)

# The query embedding is the same for every comparison below; encode it once.
sample_q_emb = bi_encoder.encode(q, convert_to_tensor=True)


def _recompute_scores(query, query_emb, text):
    """Return (cosine similarity, cross-encoder score) of ``text`` against ``query``."""
    text_emb = bi_encoder.encode(text, convert_to_tensor=True)
    cos_sim = util.pytorch_cos_sim(query_emb, text_emb)[0][0].cpu().numpy()
    cscore = cr_encoder.predict([query, text])
    return cos_sim, cscore


# further checks
print(q)
print(f'Bi-Encoder MAP: {b_map:.3f} MRR: {b_mrr:.3f}')
print(f'Cross-Encoder MAP: {c_map:.3f} MRR: {c_mrr:.3f}')
print('Hits:')
for hit in hits[:5]:
    h_text = val_text[hit['corpus_id']]
    cos_sim, cscore = _recompute_scores(q, sample_q_emb, h_text)
    print(f'{h_text}\nCos sim: Check: {cos_sim:.3f} Predicted: {hit["score"]:.3f}')
    print(f'Cross Encoder:\nCheck: {cscore:.3f} Predicted: {hit["cross_score"]:.3f}')
    print()

print('Answers:')
data_id = df[df['Question Title'] == q]['index'].values[0]
print(data_id)
answers = val_query_answer[str(data_id)]
for answer in answers[:5]:
    cos_sim, cscore = _recompute_scores(q, sample_q_emb, answer)
    print(f'{answer}\nCos sim: {cos_sim:.3f}')
    print(f'Cross Encoder:\nCheck: {cscore:.3f}')
    print()
    

In [None]:
# Evaluate every validation query and average MAP / MRR, plus precision@k and
# recall@k at the cutoffs below, for BM25, the bi-encoder and the cross-encoder.
cutoffs = [1, 3, 5, 10]
n_queries = len(val_query_answer)

bm25_scores = {'mrr': 0, 'map': 0}
bi_enc_scores = {'mrr': 0, 'map': 0, 'precision': [0] * len(cutoffs), 'recall': [0] * len(cutoffs)}
cr_enc_scores = {'mrr': 0, 'map': 0, 'precision': [0] * len(cutoffs), 'recall': [0] * len(cutoffs)}


def _add_at_k(totals, hit_list, recall_list, n_relevant):
    """Add one query's precision@k and recall@k (for each k in ``cutoffs``) to ``totals``."""
    for i, k in enumerate(cutoffs):
        totals['precision'][i] += sum(hit_list[:k]) / k
        totals['recall'][i] += sum(recall_list[:k]) / n_relevant


for query_key, answers in tqdm(val_query_answer.items(), total=n_queries):
    query = val_corpus[query_key]
    hits = forward_pass(query, top_k)
    # Pass the same top_k to every evaluator: forward_pass returns up to top_k
    # hits, and the evaluators' default of 50 would be too small if the config
    # top_k were raised.
    bm25_map, bm25_mrr = evaluate_bm25(query, answers, top_k)
    b_map, b_mrr, b_hit, b_rec = evaluate_bi_encoder(hits, answers, top_k)
    c_map, c_mrr, c_hit, c_rec = evalaute_cr_encoder(hits, answers, top_k)

    _add_at_k(bi_enc_scores, b_hit, b_rec, len(answers))
    _add_at_k(cr_enc_scores, c_hit, c_rec, len(answers))

    bm25_scores['map'] += bm25_map
    bm25_scores['mrr'] += bm25_mrr
    bi_enc_scores['map'] += b_map
    bi_enc_scores['mrr'] += b_mrr
    cr_enc_scores['map'] += c_map
    cr_enc_scores['mrr'] += c_mrr

# Turn the per-query sums into means over all queries.
for totals in (bm25_scores, bi_enc_scores, cr_enc_scores):
    for key in list(totals):
        value = totals[key]
        if isinstance(value, list):
            totals[key] = [v / n_queries for v in value]
        else:
            totals[key] = value / n_queries

In [None]:
# Pretty-print the averaged metrics of each retriever as indented JSON.
for label, scores in (
    ('BM25', bm25_scores),
    ('Bi-Encoder', bi_enc_scores),
    ('Cross-Encoder', cr_enc_scores),
):
    print(f'{label}:\n{json.dumps(scores, indent=2)}')