In [None]:
import json
import pickle
from typing import List, Optional

import faiss
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import ndcg_score
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    PreTrainedTokenizerFast,
    DataCollatorWithPadding,
    BatchEncoding
)
from transformers import AutoModel, AutoTokenizer, BatchEncoding, PreTrainedTokenizerFast
from transformers.modeling_outputs import BaseModelOutput

In [None]:
# ---- Configuration ----
# Fine-tuned bi-encoder checkpoint directory (model weights + tokenizer files)
model_name_or_path = "checkpoints/biencoder_2023-09-18-1901.46"

# Inputs: passage corpus (CSV) and evaluation queries (JSON)
legal_corpus_path = "data/sent_truncated_vbpl_update.csv"
test_path = "data/dvc_test.json"

# Location of the serialized faiss index
index_path = "me5_small_0919.index"

# Tokenizer truncation budgets, in tokens
q_max_length = 32   # intended for query encoding
p_max_length = 144  # intended for passage encoding

In [37]:
# Load the legal-document corpus, drop rows with any missing field, and
# renumber rows 0..n-1 so that positional (iloc) and label indices agree.
vbpl = (
    pd.read_csv(legal_corpus_path, encoding='utf-16')
    .dropna()
    .reset_index(drop=True)
)
# Passage texts to embed, in corpus row order
passages: List[str] = list(vbpl['truncated_text'].astype(str))

In [39]:
# Load the evaluation queries. The evaluation loop iterates this as a list of
# dicts with keys "noi_dung_hoi" (query text) and "vb_lien_quan" (relevant articles).
# JSON is UTF-8 by spec; pass it explicitly so Vietnamese text decodes on any platform.
with open(test_path, 'r', encoding='utf-8') as file:
    test = json.load(file)
# Keep the old name bound as an alias for any code that still refers to `train`.
train = test

In [40]:
# Tokenizer and encoder both come from the same fine-tuned checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModel.from_pretrained(model_name_or_path)

In [57]:
def l2_normalize(x: torch.Tensor):
    """Scale ``x`` to unit Euclidean (L2) length along its last dimension.

    Thin wrapper over ``torch.nn.functional.normalize``, which divides by
    ``max(norm, eps)``; an all-zero vector therefore stays all-zero.
    """
    functional = torch.nn.functional
    return functional.normalize(x, dim=-1, p=2)

def encode_query(tokenizer: PreTrainedTokenizerFast, query: str,
                 max_length: Optional[int] = None) -> BatchEncoding:
    """Tokenize a query into PyTorch tensors for the bi-encoder.

    Args:
        tokenizer: tokenizer loaded from the model checkpoint.
        query: raw query text.
        max_length: truncation limit in tokens. Defaults to the notebook-level
            ``q_max_length``, looked up at call time.

    Returns:
        BatchEncoding produced with ``return_tensors='pt'``.
    """
    if max_length is None:
        # Queries use the query budget, not the (longer) passage budget.
        max_length = q_max_length
    return tokenizer(query,
                     max_length=max_length,
                     padding=True,
                     truncation=True,
                     return_tensors='pt')

def encode_passage(tokenizer: PreTrainedTokenizerFast, passage: str, title: str = "",
                   max_length: Optional[int] = None) -> BatchEncoding:
    """Tokenize a (title, passage) pair into PyTorch tensors for the bi-encoder.

    Args:
        tokenizer: tokenizer loaded from the model checkpoint.
        passage: passage text, encoded as the second segment of the pair.
        title: optional first segment; empty by default.
        max_length: truncation limit in tokens. Defaults to the notebook-level
            ``p_max_length``, looked up at call time.

    Returns:
        BatchEncoding produced with ``return_tensors='pt'``.
    """
    if max_length is None:
        max_length = p_max_length
    return tokenizer(title,
                     text_pair=passage,
                     max_length=max_length,
                     padding=True,
                     truncation=True,
                     return_tensors='pt')

In [None]:
# Embed every corpus passage with the bi-encoder, one passage per forward pass.
psg_vectors = []
# Inference only: without no_grad every stored embedding would keep its whole
# autograd graph alive, and grad-tracking tensors cannot be converted to numpy.
with torch.no_grad():
    for passage in tqdm(passages):
        psg_batch_dict = encode_passage(tokenizer, passage)
        outputs: BaseModelOutput = model(**psg_batch_dict, return_dict=True)
        # First-token hidden state of the single encoded sequence, L2-normalized
        psg_vectors.append(l2_normalize(outputs.last_hidden_state[0, 0, :]).cpu().numpy())
# (num_passages, hidden_dim) float32 matrix, the dtype faiss indexes expect
psg_embeddings = np.stack(psg_vectors).astype(np.float32)

In [None]:
def normalize_embeddings(embeddings):
    """Return the rows of ``embeddings`` scaled to unit L2 norm, ready for faiss.

    Args:
        embeddings: 2-D array-like of shape (n, dim): a numpy array, a torch
            tensor, or a list of 1-D torch tensors (gradient tracking is dropped).

    Returns:
        C-contiguous float32 ``np.ndarray`` of shape (n, dim). All-zero rows are
        returned as zeros rather than NaN.
    """
    if isinstance(embeddings, torch.Tensor):
        embeddings = embeddings.detach().cpu().numpy()
    elif len(embeddings) and isinstance(embeddings[0], torch.Tensor):
        # numpy cannot convert tensors that require grad; detach them first
        embeddings = torch.stack([row.detach().cpu() for row in embeddings]).numpy()
    matrix = np.asarray(embeddings, dtype=np.float32)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Clamp the divisor so an all-zero row yields zeros instead of 0/0 = NaN
    normalized = matrix / np.maximum(norms, 1e-12)
    # faiss requires float32, C-contiguous input
    return np.ascontiguousarray(normalized, dtype=np.float32)

# Row-wise L2-normalized passage matrix; with unit vectors, inner product equals cosine similarity.
norm_psg_embeddings = normalize_embeddings(embeddings=psg_embeddings)

In [None]:
# Exact (brute-force) inner-product index over the normalized passage vectors.
# faiss requires a float32, C-contiguous (n, dim) matrix.
index_vectors = np.ascontiguousarray(norm_psg_embeddings, dtype=np.float32)
# Take the dimension from the data: a Hugging Face AutoModel has no
# get_sentence_embedding_dimension() (that is a SentenceTransformer method).
index = faiss.IndexFlatIP(index_vectors.shape[1])
index.add(index_vectors)

In [None]:
# Persist the built index to index_path so it can be reloaded with
# faiss.read_index instead of re-embedding the whole corpus.
faiss.write_index(index, index_path)

In [None]:
# Reload the serialized index from index_path; running this cell lets the
# evaluation proceed without rebuilding the index in the cells above.
import faiss
index = faiss.read_index(index_path)

In [None]:
def calculate_f2(precision, recall):
    """F-beta score with beta = 2 (recall weighted higher): 5PR / (4P + R).

    The tiny 1e-20 term keeps the denominator non-zero, so P = R = 0 gives 0.0.
    """
    numerator = 5 * precision * recall
    denominator = 4 * precision + recall + 1e-20
    return numerator / denominator
def calculate_f1(precision, recall):
    """F1 score, the harmonic mean of precision and recall: 2PR / (P + R).

    The tiny 1e-20 term keeps the denominator non-zero, so P = R = 0 gives 0.0.
    """
    return (2 * precision * recall) / (precision + recall + 1e-20)

In [None]:
# Evaluate retrieval quality on the test set
top_k = 500  # number of passages retrieved per query
thresh = 0  # NOTE(review): not used by this cell
total_f1 = 0
total_f2 = 0
total_precision = 0
total_recall = 0
total_map = 0
total_ndcg = 0

# Per-query precision / recall, kept for later inspection
precisions = []
recalls = []


def _embed_query(query_text: str) -> np.ndarray:
    """Encode one query into the (1, dim) float32 matrix that faiss search expects."""
    query_batch_dict = encode_query(tokenizer, query_text)
    # Inference only; also required so the tensor can be converted to numpy
    with torch.no_grad():
        outputs: BaseModelOutput = model(**query_batch_dict, return_dict=True)
    # First-token hidden state of the single encoded sequence, L2-normalized
    query_embedding = l2_normalize(outputs.last_hidden_state[0, 0, :])
    return np.ascontiguousarray(query_embedding.cpu().numpy().reshape(1, -1), dtype=np.float32)


def _score_ranking(hit_keys, relevant_keys):
    """Compare a ranked list of (so_hieu, dieu) keys against the relevant keys.

    Returns:
        (true_positive, false_positive, average_precision, relevance):
        - true_positive / false_positive count *distinct* keys;
        - average_precision is the mean of precision@rank over every relevant
          position (0 if nothing relevant was retrieved);
        - relevance is the 0/1 label of each ranked position.
    """
    # NOTE(review): several retrieved passages can share one (so_hieu, dieu); each
    # relevant occurrence adds to average precision, while precision/recall count
    # distinct keys. Confirm this is the intended definition.
    true_positive_set = set()
    false_positive_set = set()
    num_hits = 0
    precision_sum = 0
    relevance = []
    for rank, key in enumerate(hit_keys, 1):
        if key in relevant_keys:
            relevance.append(1)
            true_positive_set.add(key)
            num_hits += 1
            precision_sum += num_hits / rank
        else:
            relevance.append(0)
            false_positive_set.add(key)
    average_precision = precision_sum / num_hits if num_hits else 0
    return len(true_positive_set), len(false_positive_set), average_precision, relevance


for item in tqdm(test):
    query = item["noi_dung_hoi"]
    relevant_articles = item["vb_lien_quan"]
    dict_relevant = {(article["so_hieu"], article["dieu"]): article for article in relevant_articles}
    # Distinct relevant keys, so recall uses the same unit as the distinct true positives
    actual_positive = len(dict_relevant)

    scores, indices = index.search(_embed_query(query), top_k)
    # faiss pads results with label -1 when the index holds fewer than top_k
    # vectors; iloc[-1] would otherwise silently return the last corpus row.
    found = indices[0] >= 0
    hit_scores = scores[0][found]
    hits = vbpl.iloc[indices[0][found]]
    # NOTE(review): CSV and JSON keys must have matching types (e.g. both str)
    # for the membership test to succeed; verify against the data files.
    hit_keys = list(zip(hits["so_hieu"], hits["dieu"]))

    true_positive, false_positive, average_precision, actual_relevance = _score_ranking(
        hit_keys, dict_relevant
    )

    # sklearn's ndcg_score needs at least two ranked documents; for a single hit
    # NDCG is simply 1 if it is relevant and 0 otherwise.
    if len(actual_relevance) > 1:
        ndcg = ndcg_score([actual_relevance], [hit_scores], k=top_k)
    else:
        ndcg = float(sum(actual_relevance))

    precision = true_positive / (true_positive + false_positive + 1e-20)
    # A query with no relevant articles contributes zero recall instead of crashing
    recall = true_positive / actual_positive if actual_positive else 0.0
    f1 = calculate_f1(precision, recall)
    f2 = calculate_f2(precision, recall)

    total_precision += precision
    total_recall += recall
    total_f1 += f1
    total_f2 += f2
    total_map += average_precision
    total_ndcg += ndcg

    precisions.append(precision)
    recalls.append(recall)

In [None]:
# Average each accumulated per-query metric over the test set.
N = len(test)
if N == 0:
    # Fail with a clear message instead of a bare ZeroDivisionError
    raise ValueError("The test set is empty; there is nothing to average.")
for label, total in (
    ("Recall", total_recall),
    ("Precision", total_precision),
    ("F2", total_f2),
    ("F1", total_f1),
    ("MAP", total_map),
    ("NDCG", total_ndcg),
):
    print(f"{label}: {total/N}")