In [None]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util, models
from sklearn.feature_extraction import _stop_words as stop_words
from tqdm.notebook import tqdm
from rank_bm25 import BM25Okapi

import torch
import string
import json
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

import numpy as np
import pandas as pd

In [None]:
# Experiment configuration.
# Directories holding the fine-tuned bi-encoder / cross-encoder weights; the
# model-loading cell reads them only when `use_base` is False.
bi_enc_weights = '../models/bienc-exp7/'
cr_enc_weights = '../models/crenc-exp7/'
# NOTE(review): not referenced by any cell visible in this part of the notebook — confirm it is still needed.
data_folder = 'generated5'
# NOTE(review): forward_pass_rerank() declares its own `top_k=50` default and the
# visible calls never pass this value, so editing it here does not affect them.
top_k = 50
# True  -> build the encoders from the 'distilroberta-base' and
#          'cross-encoder/ms-marco-TinyBERT-L-6' model names;
# False -> load the encoders from the weight directories above.
use_base = False

In [None]:
# Read the spreadsheet with its first column as the index and preview two rows.
# NOTE(review): `df` is not used by any other cell visible here — confirm it is still needed.
df = pd.read_excel('../data/20231004_data.xlsx', index_col=0)
df.head(2)

In [None]:
# NLTK's English stop-word list, held as a set; bm25_tokenizer() drops any token found in it.
english_stopwords = set(stopwords.words('english'))

def bm25_tokenizer(text):
  """Split ``text`` into lowercase, punctuation-trimmed, stop-word-free tokens.

  The text is lowercased and split on whitespace. Each piece has leading and
  trailing ``string.punctuation`` characters stripped (punctuation inside a
  piece is kept). Pieces that end up empty, or that appear in the module-level
  ``english_stopwords`` set, are discarded.

  Args:
    text: The string to tokenize.

  Returns:
    A list with the surviving tokens in their original order.
  """
  stripped = (piece.strip(string.punctuation) for piece in text.lower().split())
  return [tok for tok in stripped if tok and tok not in english_stopwords]

In [None]:
# Choose which bi-encoder (retriever) / cross-encoder pair to load.
if use_base:
    # Baseline: assemble a bi-encoder from the 'distilroberta-base' transformer
    # (constructed with max_seq_length=350) followed by a pooling layer sized to
    # the transformer's word-embedding dimension.
    # NOTE(review): Pooling is created with its default settings — confirm that is the intended pooling mode.
    word_embedding_model = models.Transformer('distilroberta-base', max_seq_length=350)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    bi_encoder = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    cr_encoder = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-6')
else:
    # Load both encoders from the weight directories set in the configuration cell.
    bi_encoder = SentenceTransformer(bi_enc_weights)
    cr_encoder = CrossEncoder(cr_enc_weights)

In [None]:
# Load the two evaluation JSON files. The cell's loop below reads `val_passage`
# as {id: iterable of related ids} and `val_corpus` as {id: text}.
# NOTE(review): both paths are absolute and machine-specific; consider deriving
# them from a configurable data directory so the notebook runs elsewhere.
with open('/Users/Documents/final-code-microservice-paper/experiments/test_passage_100.json', 'r') as f:
    val_passage = json.load(f)

with open('/Users/Documents/final-code-microservice-paper/experiments/test_corpus_100.json', 'r') as f:  # Note the _100 in the filename
    val_corpus = json.load(f)

# Build {query text: [answer texts]} from the relevance map in `val_passage`.
val_query_answer = {}
for idx, rel in val_passage.items():
    query = val_corpus.get(idx)
    if not query:
        # The id is missing from the corpus (or maps to empty text): skip it.
        continue

    # Resolve the related ids to their texts, ignoring ids the corpus lacks.
    answers = [val_corpus[key] for key in map(str, rel) if key in val_corpus]
    if answers:
        val_query_answer[query] = answers



In [None]:
# Encode every corpus text once with the bi-encoder. `val_text` keeps the texts
# in the same order they were encoded, so the `corpus_id` of a search hit can be
# used directly as an index into `val_text` (as forward_pass_rerank() does).
val_text = list(val_corpus.values())
val_emb = bi_encoder.encode(val_text, show_progress_bar=True, convert_to_tensor=True)

In [None]:
# NOTE(review): this rebinds `tqdm` from the `tqdm.notebook` variant imported in
# the first cell to the plain `tqdm.tqdm`, for this and every later cell.
from tqdm import tqdm
# Tokenize each corpus passage with bm25_tokenizer().
# NOTE(review): no cell visible here consumes `tokenized_corpus` (no BM25Okapi
# index is built from it) — confirm it is still needed.
tokenized_corpus = []
for idx, passage in tqdm(val_corpus.items()):
    tokenized_corpus.append(bm25_tokenizer(passage))

In [None]:
# Total correct answers: across all queries, how many of the top-5 hits are listed answers.
import warnings
# Suppress one specific warning message.
# NOTE(review): presumably emitted by a scikit-learn ranking metric; nothing in
# the visible cells computes one — confirm this filter is still needed.
warnings.filterwarnings("ignore", message="No positive class found in y_true, recall is set to one for all thresholds.")

def forward_pass_rerank(query, top_k=50):
    """Retrieve candidates with the bi-encoder, then re-rank them with the cross-encoder.

    Steps:
      1. Encode ``query`` with ``bi_encoder`` and fetch ``top_k + 1`` nearest
         corpus entries from ``val_emb`` via ``util.semantic_search``.
      2. Drop the hit whose text is identical to ``query`` (the query is itself
         part of the corpus) and keep at most ``top_k`` candidates.
      3. Score every ``[query, candidate_text]`` pair with ``cr_encoder`` and
         store the score on the hit under ``'cross_score'``.
      4. Sort the candidates by ``'cross_score'``, best first.

    The previous implementation attached the cross-encoder scores but returned
    the hits in bi-encoder order, so the re-ranking step had no effect on which
    hits callers saw first. The set of returned hits is unchanged; only their
    order now reflects the cross-encoder.

    NOTE(review): this notebook defines ``forward_pass_rerank`` in two
    evaluation cells; keep the copies identical or define it once.

    Args:
        query: Query text. It is compared by equality against corpus texts to
            detect its own corpus entry.
        top_k: Maximum number of hits to return.

    Returns:
        A list of at most ``top_k`` hit dicts as produced by
        ``util.semantic_search``, each extended with a ``'cross_score'`` entry,
        ordered by descending ``'cross_score'``. Empty if nothing was retrieved.
        Assumes ``cr_encoder.predict`` yields one sortable score per pair.
    """
    q_emb = bi_encoder.encode(query, convert_to_tensor=True)
    # One extra candidate so that removing the query's own entry still leaves top_k.
    hits = util.semantic_search([q_emb], val_emb, top_k=top_k + 1)[0]

    # Locate the query's own corpus entry; as before, the last match wins.
    self_idx = -1
    for pos, hit in enumerate(hits):
        if val_text[hit['corpus_id']] == query:
            self_idx = pos
    if self_idx != -1:
        del hits[self_idx]
    hits = hits[:top_k]

    if not hits:
        # Nothing to score; avoid calling the cross-encoder on an empty batch.
        return []

    cross_inputs = [[query, val_text[hit['corpus_id']]] for hit in hits]
    cross_scores = cr_encoder.predict(cross_inputs)
    for hit, score in zip(hits, cross_scores):
        hit['cross_score'] = score

    # The actual re-ranking: order by cross-encoder score (stable for ties).
    hits.sort(key=lambda hit: hit['cross_score'], reverse=True)
    return hits

# Total correct answers: summed over all queries, the number of top-5 hits
# whose text is one of that query's listed answers.
total_correct = 0

for query_key, answers in tqdm(val_query_answer.items(), total=len(val_query_answer)):
    hits = forward_pass_rerank(query_key)
    # `answers` is already the value paired with `query_key`, so no second
    # dictionary lookup is needed; each True in the sum counts as 1.
    total_correct += sum(val_text[hit['corpus_id']] in answers for hit in hits[:5])

print(f"Total correct answers: {total_correct}")


In [None]:
# Total questions answered correctly: how many queries have at least one listed answer among their top-5 hits.
import warnings
# Suppress one specific warning message.
# NOTE(review): presumably emitted by a scikit-learn ranking metric; nothing in
# the visible cells computes one — confirm this filter is still needed.
warnings.filterwarnings("ignore", message="No positive class found in y_true, recall is set to one for all thresholds.")

def forward_pass_rerank(query, top_k=50):
    """Retrieve candidates with the bi-encoder, then re-rank them with the cross-encoder.

    Steps:
      1. Encode ``query`` with ``bi_encoder`` and fetch ``top_k + 1`` nearest
         corpus entries from ``val_emb`` via ``util.semantic_search``.
      2. Drop the hit whose text is identical to ``query`` (the query is itself
         part of the corpus) and keep at most ``top_k`` candidates.
      3. Score every ``[query, candidate_text]`` pair with ``cr_encoder`` and
         store the score on the hit under ``'cross_score'``.
      4. Sort the candidates by ``'cross_score'``, best first.

    The previous implementation attached the cross-encoder scores but returned
    the hits in bi-encoder order, so the re-ranking step had no effect on which
    hits callers saw first. The set of returned hits is unchanged; only their
    order now reflects the cross-encoder.

    NOTE(review): this notebook defines ``forward_pass_rerank`` in two
    evaluation cells; keep the copies identical or define it once.

    Args:
        query: Query text. It is compared by equality against corpus texts to
            detect its own corpus entry.
        top_k: Maximum number of hits to return.

    Returns:
        A list of at most ``top_k`` hit dicts as produced by
        ``util.semantic_search``, each extended with a ``'cross_score'`` entry,
        ordered by descending ``'cross_score'``. Empty if nothing was retrieved.
        Assumes ``cr_encoder.predict`` yields one sortable score per pair.
    """
    q_emb = bi_encoder.encode(query, convert_to_tensor=True)
    # One extra candidate so that removing the query's own entry still leaves top_k.
    hits = util.semantic_search([q_emb], val_emb, top_k=top_k + 1)[0]

    # Locate the query's own corpus entry; as before, the last match wins.
    self_idx = -1
    for pos, hit in enumerate(hits):
        if val_text[hit['corpus_id']] == query:
            self_idx = pos
    if self_idx != -1:
        del hits[self_idx]
    hits = hits[:top_k]

    if not hits:
        # Nothing to score; avoid calling the cross-encoder on an empty batch.
        return []

    cross_inputs = [[query, val_text[hit['corpus_id']]] for hit in hits]
    cross_scores = cr_encoder.predict(cross_inputs)
    for hit, score in zip(hits, cross_scores):
        hit['cross_score'] = score

    # The actual re-ranking: order by cross-encoder score (stable for ties).
    hits.sort(key=lambda hit: hit['cross_score'], reverse=True)
    return hits

# Total questions answered correctly: the number of queries for which at
# least one of the top-5 hits is one of that query's listed answers.
total_questions_answered_correctly = 0

for query_key, answers in tqdm(val_query_answer.items(), total=len(val_query_answer)):
    hits = forward_pass_rerank(query_key)
    # `answers` is already the value paired with `query_key`, so no second
    # dictionary lookup is needed.
    if any(val_text[hit['corpus_id']] in answers for hit in hits[:5]):
        total_questions_answered_correctly += 1

print(f"Total Questions Answered Correctly: {total_questions_answered_correctly}")
