In [25]:
# Cell 1: Importing necessary libraries
from datasets import load_dataset
import unicodedata as ud
import re
from tqdm import tqdm
from rank_bm25 import BM25Okapi
import string
import pandas as pd
import json
from underthesea import word_tokenize

# Vietnamese stop words, stored as single whitespace-separated syllables.
# Question tokens found in this set are dropped by the query preprocessing step.
# Some entries are repeated in the listing below; a set keeps one copy of each.
stop_words_vn = {
    "của", "và", "về", "trong", "được", "là", "các", "bởi", "để", "này",
    "theo", "một", "hoặc", "với", "tại", "khi", "thì", "nếu", "mà", "đã",
    "những", "có", "không", "trên", "dưới", "ra", "vẫn", "lại", "rất",
    "cũng", "như", "bằng", "từ", "sẽ", "phải", "giữa", "qua", "từng",
    "thông", "báo", "căn", "cứ", "này", "phạm", "vi", "chung", "áp",
    "dụng", "chỉ", "thứ", "trách", "nhiệm", "hữu", "hạn", "công", "ty",
    "cổ", "phần", "hợp", "doanh", "nghiệp", "tư", "nhân", "bao", "gồm",
    "gọi", "tên", "sau", "cơ", "quan", "tổ", "chức", "hoạt", "động",
    "liên", "quan", "thành", "lập", "lại", "giải", "thể", "quy", "định",
    "quyền", "nghĩa", "vụ", "bản", "sao", "giấy", "tờ", "sổ", "chính",
    "thẩm", "đối", "chiếu", "cá", "nhân", "nước", "ngoài", "mang",
    "người", "nhà", "địa", "số", "luật", "pháp", "việc", "điều", "khoản",
}


In [26]:
# Cell 2: Loading the dataset
# Read the chunk corpus from a local JSON file and materialize it as a list of
# records. Later cells read the "passage" and "id" fields of each record and
# attach a "bm25_score" field to them.
# NOTE(review): the path is absolute and machine-specific, so this cell only
# runs on the author's machine; consider a configurable data directory.
meta_corpus = load_dataset(
    "json",
    data_files="/Users/nhotin/Documents/GitHub/LegalBizAI_project/test_set/id_cof/chunk_sz_fl_point/all_chunk_final.json",
    split="train"
).to_list()

In [27]:
# Cell 3: Function to split text
def split_text(text):
    """Tokenize ``text`` into lowercase, whitespace-separated words.

    The text is normalized to Unicode NFC first. Vietnamese letters can be
    encoded either precomposed or as a base letter followed by combining
    marks; without normalization the two encodings of the same word become
    different tokens and never match each other in BM25.

    Args:
        text: The string to tokenize.

    Returns:
        A list of lowercase tokens with ASCII punctuation removed. Empty or
        whitespace-only input yields an empty list.
    """
    text = ud.normalize("NFC", text)
    # Punctuation characters are deleted (not replaced by spaces), so "a.b"
    # becomes the single token "ab".
    text = text.translate(str.maketrans('', '', string.punctuation))
    # str.split() with no separator never yields empty strings, so no extra
    # filtering of blank tokens is required.
    return text.lower().split()

In [28]:
# Cell 4: Function to retrieve relevant chunks using BM25
def retrieve(question, topk=50):
    """Return the ``topk`` corpus passages with the highest BM25 score.

    Args:
        question: Query string; it is tokenized with ``split_text``.
        topk: Number of passages to return. It is applied as a slice bound
            on the ranked list.

    Returns:
        A list of dicts ordered by descending score. Each dict is a shallow
        copy of a ``meta_corpus`` record with an added ``"bm25_score"`` key.
        Copies are returned so that a later call cannot overwrite the scores
        held by the result of an earlier call, and so the shared corpus
        records are not modified by querying.
    """
    tokenized_query = split_text(question)
    bm25_scores = bm25.get_scores(tokenized_query)
    # Rank positions instead of records; sorted() is stable, so passages with
    # equal scores keep their corpus order.
    ranked = sorted(
        range(len(meta_corpus)),
        key=lambda i: bm25_scores[i],
        reverse=True,
    )
    return [
        {**meta_corpus[i], "bm25_score": bm25_scores[i]}
        for i in ranked[:topk]
    ]

In [29]:
# Cell 5: Initiate BM25 retriever with parameter tuning
# Tokenize the "passage" field of every corpus record with split_text; tqdm
# only adds a progress bar around the iteration.
tokenized_corpus = [split_text(doc["passage"]) for doc in tqdm(meta_corpus)]
# Build the BM25 index used by retrieve(). k1 and b are passed explicitly here
# rather than left at the library defaults.
# NOTE(review): the values 2.0 / 0.75 are not justified anywhere in this notebook.
bm25 = BM25Okapi(tokenized_corpus, k1=2.0, b=0.75)  # Adjust parameters k1 and b

100%|██████████| 4162/4162 [00:00<00:00, 33173.38it/s]


In [30]:
# Cell 6: Function to get the top similar chunks
def getTopSimi(top_retrive):
    """Keep only the passages whose BM25 score is above the mean score.

    Args:
        top_retrive: List of passage dicts, each with ``"bm25_score"`` and
            ``"id"`` keys.

    Returns:
        A dict with two keys: ``"copus"`` (the retained passage dicts, in
        input order) and ``"ids"`` (their ``"id"`` values, same order).
        An empty input returns empty lists instead of dividing by zero.

    Note:
        The comparison is strict, so when every score is identical
        (including the single-passage case) nothing is retained.
    """
    if not top_retrive:
        # Guard: the mean of zero scores is undefined.
        return {"copus": [], "ids": []}
    avg_score = sum(each["bm25_score"] for each in top_retrive) / len(top_retrive)
    best_retrive = [each for each in top_retrive if each["bm25_score"] > avg_score]
    # The "copus" key spelling is kept as-is because callers may read it.
    return {"copus": best_retrive, "ids": [each["id"] for each in best_retrive]}

In [31]:
# Cell 7: Function to get the full article passage
def get_full_article(chunks: list[dict], chunk_ids: list[int]) -> dict:
    """Expand chunk ids to every contiguous chunk sharing the same title.

    Each id in ``chunk_ids`` is used as a position in ``chunks``. From that
    position the selection grows backwards and forwards while neighbouring
    chunks carry the same ``"title"``.

    Args:
        chunks: Chunk dicts with ``"title"`` and ``"passage"`` keys.
        chunk_ids: Positions in ``chunks`` to start from.

    Returns:
        A dict with ``"ids"`` (sorted list of all selected positions) and
        ``"content"`` (newline-joined text: each title once, followed by the
        passage lines of its chunks with the first line of every passage
        dropped).
    """
    total = len(chunks)
    selected = set()
    for seed in chunk_ids:
        if seed in selected:
            # Already covered by a previously expanded article.
            continue
        title = chunks[seed]["title"]
        first = seed
        while first > 0 and chunks[first - 1]["title"] == title:
            first -= 1
        last = seed
        while last + 1 < total and chunks[last + 1]["title"] == title:
            last += 1
        selected.add(seed)
        selected.update(range(first, seed))
        selected.update(range(seed + 1, last + 1))

    ordered_ids = sorted(selected)
    lines = []
    current_title = ""
    for position in ordered_ids:
        chunk = chunks[position]
        if chunk["title"] != current_title:
            # Emit the title once per run of chunks that share it.
            current_title = chunk["title"]
            lines.append(current_title)
        # The first passage line is skipped (presumably it repeats the
        # title; confirm against the data file).
        lines += chunk["passage"].splitlines()[1:]
    return {"ids": ordered_ids, "content": "\n".join(lines)}

In [32]:
# Cell 8: Function to calculate F1 beta score
def f1_beta(pred, actual, beta=4):
    """Compute micro-averaged precision, recall and F-beta over id lists.

    ``pred`` and ``actual`` are paired element by element; each pair is
    compared as sets and the true-positive / false-positive / false-negative
    counts are summed over all pairs before the ratios are taken.

    Args:
        pred: Iterable of predicted id collections.
        actual: Iterable of expected id collections.
        beta: Weight of recall relative to precision in the F-beta score.

    Returns:
        A dict with keys ``"recall"``, ``"precision"`` and ``"f1_beta"``.
        Any ratio with a zero denominator is reported as 0.
    """
    true_pos = false_pos = false_neg = 0
    for predicted, expected in zip(pred, actual):
        predicted_set, expected_set = set(predicted), set(expected)
        overlap = len(predicted_set & expected_set)
        true_pos += overlap
        false_pos += len(predicted_set) - overlap
        false_neg += len(expected_set) - overlap

    predicted_total = true_pos + false_pos
    expected_total = true_pos + false_neg
    precision = true_pos / predicted_total if predicted_total > 0 else 0
    recall = true_pos / expected_total if expected_total > 0 else 0

    f1_beta_score = 0
    if precision + recall > 0:
        f1_beta_score = (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall)

    return {"recall": recall, "precision": precision, "f1_beta": f1_beta_score}

In [33]:
# Cell 9: Load and process the question-answer dataset
# Read the question/answer set and keep only the two columns used for
# evaluation: the question text and its expected chunk ids.
df = pd.read_json(
    "/Users/nhotin/Documents/GitHub/LegalBizAI_project/test_set/id_cof/chunk_sz_fl_point/qasetfinal.json"
)[["question", "chunk_ids"]]

In [34]:
# Cell 10: Load all chunks
# Load the chunk file again as plain Python objects via json.load. This is
# the same path that backs meta_corpus; get_full_article indexes this list by
# chunk id.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
filepath = "/Users/nhotin/Documents/GitHub/LegalBizAI_project/test_set/id_cof/chunk_sz_fl_point/all_chunk_final.json"
with open(filepath, "r", encoding="utf-8") as f:
    all_chunks = json.load(f)

In [35]:
# Cell 11: Function to get retrieval ids for a question
def retrieval_ids(question, topk=3):
    """Predict the chunk ids relevant to ``question``.

    Pipeline: clean the question, retrieve the ``topk`` best BM25 passages,
    keep those scoring above the mean, then expand them to whole articles.

    Args:
        question: Raw question string.
        topk: Number of BM25 candidates to retrieve before filtering.
            Defaults to 3, the value that was previously hard-coded.

    Returns:
        The sorted list of chunk ids produced by ``get_full_article``.
    """
    def preprocessing(promt):
        # Lowercase, delete every character that is neither a word character
        # nor whitespace, then drop stop-word syllables.
        cleaned = re.sub(r'[^\w\s]', '', promt.strip().lower())
        return " ".join(word for word in cleaned.split() if word not in stop_words_vn)

    top_chunk = retrieve(preprocessing(question), topk=topk)
    best_chunk_ids = getTopSimi(top_chunk)["ids"]
    return get_full_article(all_chunks, best_chunk_ids)["ids"]

In [36]:
# Cell 12: Apply retrieval_ids function and evaluate the model
# Expected chunk ids, one list per question.
actual = df["chunk_ids"].tolist()
# Run the retrieval pipeline on every question and keep the predictions in
# the frame so they can be inspected alongside the expected ids.
df["pred_ids"] = df["question"].apply(retrieval_ids)
pred = df["pred_ids"].tolist()
print(f1_beta(pred, actual))

{'recall': 0.47810713343301964, 'precision': 0.3860566201224449, 'f1_beta': 0.4714940724692552}


In [37]:
# Cell 13: Display the dataframe and compare predictions with actual chunk ids
# NOTE(review): with IPython's default settings only the last expression of a
# cell is rendered, so this bare `df` produces no output.
df
# Rows whose predicted id list equals the expected id list exactly.
df[df["chunk_ids"] == df["pred_ids"]]

Unnamed: 0,question,chunk_ids,pred_ids
4,Thứ tự thanh toán khoản nợ của doanh nghiệp gi...,"[1683, 1684, 1685, 1686, 1687, 1688, 1689, 169...","[1683, 1684, 1685, 1686, 1687, 1688, 1689, 169..."
6,Các khoản nợ của doanh nghiệp tư nhân giải thể...,"[1683, 1684, 1685, 1686, 1687, 1688, 1689, 169...","[1683, 1684, 1685, 1686, 1687, 1688, 1689, 169..."
7,Tình trạng đã phá sản của doanh nghiệp là gì?,"[2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021]","[2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021]"
9,Hội đồng giải thể doanh nghiệp do Nhà nước nắm...,"[2799, 2800, 2801, 2802, 2803, 2804, 2805, 280...","[2799, 2800, 2801, 2802, 2803, 2804, 2805, 280..."
11,Doanh nghiệp do Nhà nước nắm giữ 100% vốn điều...,"[2763, 2764, 2765, 2766, 2767, 2768, 2769, 277...","[2763, 2764, 2765, 2766, 2767, 2768, 2769, 277..."
...,...,...,...
5547,Trường hợp nào sẽ tiến hành bầu dồn phiếu?,"[1158, 1159, 1160, 1161, 1162, 1163, 1164, 116...","[1158, 1159, 1160, 1161, 1162, 1163, 1164, 116..."
5551,Vốn có quyền biểu quyết là gì?,"[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,...","[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,..."
5555,Quyền được cung cấp thông tin của Ban kiểm soá...,"[1414, 1415, 1416, 1417, 1418, 1419]","[1414, 1415, 1416, 1417, 1418, 1419]"
5559,Điều kiện để công ty TNHH 2 thành viên trở lên...,[548],[548]


In [38]:
# Cell 14: Experiment with different top-k values
# Sweep the number of BM25 candidates. Unlike retrieval_ids, the raw question
# is passed to retrieve() without stop-word preprocessing, and df["pred_ids"]
# is overwritten on every iteration (it ends up holding the last top-k run).
for topk in (5, 10, 15, 20):
    def retrieval_ids_experiment(question):
        """Predict chunk ids for ``question`` using the current ``topk``."""
        candidates = retrieve(question, topk=topk)
        kept_ids = getTopSimi(candidates)["ids"]
        return get_full_article(all_chunks, kept_ids)["ids"]

    actual = df["chunk_ids"].tolist()
    df["pred_ids"] = df["question"].apply(retrieval_ids_experiment)
    pred = df["pred_ids"].to_list()
    scores = f1_beta(pred, actual)
    print(f"Top-k: {topk}, Recall: {scores['recall']}, Precision: {scores['precision']}, F1-beta: {scores['f1_beta']}")


Top-k: 5, Recall: 0.6387545078722843, Precision: 0.4082298049356344, F1-beta: 0.6182189661129834
Top-k: 10, Recall: 0.7219456416571378, Precision: 0.2900528666742056, F1-beta: 0.6638037008281573
Top-k: 15, Recall: 0.7557216993579031, Precision: 0.22761785167405807, F1-beta: 0.6649679081231739
Top-k: 20, Recall: 0.7801917494942388, Precision: 0.18732998800412254, F1-beta: 0.6577432153737173
