## 1. Install required libraries

In [None]:
!pip install underthesea
!pip install qdrant-client sentence-transformers

## 2. Upload data files

In [None]:
# Extract the uploaded TVPL_AI.zip archive into the current working directory.
# NOTE(review): later cells read from 'TVPL_AI/data', so the archive is presumably rooted at TVPL_AI/ — confirm.
!unzip "TVPL_AI.zip" -d .

## 3. Preprocess the data

In [25]:
import os
import re
import json
from underthesea import sent_tokenize
import xml.etree.ElementTree as ET

# Input: folder containing the raw .docx law files.
data_dir = 'TVPL_AI/data'

# Output: folder that receives one JSON file per processed document (created if missing).
output_dir = os.path.join('data', 'processed')
os.makedirs(output_dir, exist_ok=True)

In [26]:
def repl(match):
    """
    Regex replacement callback: expand a captured "millions" figure into a VNĐ amount.

    The first capture group of `match` holds the numeric part (e.g. "1.2"); the result
    is that value times one million, rounded to a whole number, with dots as
    thousands separators and a " VNĐ" suffix.
    """
    millions = float(match.group(1))
    grouped = format(millions * 1_000_000, ",.0f")
    return grouped.replace(",", ".") + " VNĐ"

def normalize_number(text):
    """
    Rewrite shorthand "million" amounts as full figures, e.g. 1.2m -> 1.200.000 VNĐ.

    Every number immediately followed by a standalone "m" is passed to `repl`.
    """
    million_pattern = re.compile(r'(\d+\.?\d*)m\b')
    return million_pattern.sub(repl, text)

def clean_text(text):
    """
    Drop special characters and normalize whitespace.

    Keeps word characters, whitespace, characters in the À-ỹ range and the
    punctuation marks . , : ; — everything else is removed.
    """
    without_symbols = re.sub(r'[^\w\sÀ-ỹ.,:;\n]', '', text)
    # Any run of whitespace (line breaks included) collapses into a single space.
    collapsed = re.sub(r'\s+', ' ', without_symbols)
    return collapsed.strip()

def extract_title(after_dieu):
    """
    Pull the heading out of the text that follows an article marker.

    The heading is the first run of non-newline characters, so it may be long and
    contain commas, but it ends at the first line break.

    Returns a tuple (title, end_pos): the stripped heading and the offset in
    `after_dieu` where the heading match ends. When the text holds no such run
    (empty or newlines only), the stripped text and its full length are returned.
    """
    found = re.search(r'(.+?)(?=\n\d+\.\s|\n\d+\.|\n|$)', after_dieu)
    if found is None:
        return after_dieu.strip(), len(after_dieu)
    return found.group(1).strip(), found.end()

def split_articles(raw_text):
    """
    Split marked-up law text into articles.

    An article starts at a bold heading of the form "[BOLD]Điều <number>." (the
    period is mandatory) and runs until the next such heading or the end of the text.
    Because every bold run carries its own "[BOLD]" prefix, a heading that was split
    across several runs looks like "[BOLD]Điều [BOLD]12[BOLD]."; stray markers
    between "Điều", the number and the period are therefore tolerated.

    Returns a dict mapping "dieu_<number>" to {"title": ..., "text": ...} with all
    "[BOLD]" markers removed. Articles whose body is empty are skipped (a message is
    printed). If the same article number occurs more than once, the last one wins.
    """
    # Whitespace and/or bold markers that may sit between the parts of a heading.
    gap = r'(?:\[BOLD\]|\s)'
    pattern = r'\[BOLD\]Điều' + gap + r'+(\d+)' + gap + r'*\.'
    matches = list(re.finditer(pattern, raw_text, re.IGNORECASE))
    articles = {}

    for idx, match in enumerate(matches):
        start = match.start()
        end = matches[idx + 1].start() if idx + 1 < len(matches) else len(raw_text)
        article_text = raw_text[start:end].strip()

        # The number is captured by the heading pattern itself.
        article_num = match.group(1)

        # article_text begins with the heading, so slicing by its length leaves title + body.
        after_dieu = article_text[len(match.group(0)):].strip()
        title, title_end = extract_title(after_dieu)
        content = after_dieu[title_end:].strip()

        # If content is empty => skip
        if not content:
            print(f"Skip Article {article_num} because text is blank.")
            continue

        articles[f"dieu_{article_num}"] = {
            "title": title.replace("[BOLD]", "").strip(),
            "text": content.replace("[BOLD]", "").strip()
        }

    return articles

In [27]:
def read_docx(docx_path):
    """
    Read the text of a DOCX file, marking bold runs.

    Opens the archive at `docx_path`, parses 'word/document.xml' and returns one
    line per non-empty paragraph (joined with '\\n'). The text of every bold run is
    prefixed with "[BOLD]".

    A run counts as bold only when its run properties contain <w:b> that is not
    explicitly switched off (w:val of "0", "false" or "off"). All <w:t> children
    of a run are concatenated, so runs holding several text pieces lose nothing.
    """
    import zipfile
    import xml.etree.ElementTree as ET

    w_uri = 'http://schemas.openxmlformats.org/wordprocessingml/2006/main'
    ns = {'w': w_uri}
    val_attr = '{' + w_uri + '}val'

    with zipfile.ZipFile(docx_path) as z:
        xml_content = z.read('word/document.xml')

    tree = ET.XML(xml_content)

    paragraphs = []
    for para in tree.findall('.//w:p', ns):
        runs = []

        for run in para.findall('.//w:r', ns):
            # <w:b> is a toggle: no w:val means "on", an explicit 0/false/off means "not bold".
            bold = run.find('w:rPr/w:b', ns)
            is_bold = (
                bold is not None
                and bold.get(val_attr, 'true').lower() not in ('0', 'false', 'off')
            )

            # A run may hold more than one <w:t>; keep all of them, in order.
            text = ''.join(t.text or '' for t in run.findall('w:t', ns))
            if text:
                runs.append(f"[BOLD]{text}" if is_bold else text)

        if runs:
            paragraphs.append(''.join(runs))

    return '\n'.join(paragraphs)

def process_files():
    """
    Convert every .docx file in `data_dir` into a JSON file in `output_dir`.

    Each document is read with `read_docx`, split with `split_articles`, and the
    resulting dict is written as UTF-8 JSON named after the source file (only the
    final extension is swapped for '.json'). Files are handled in sorted order so
    runs are reproducible; Word lock files ('~$...') are ignored because they are
    not valid archives. Reads the module-level `data_dir` and `output_dir`.
    """
    all_files = sorted(
        f for f in os.listdir(data_dir)
        if f.lower().endswith('.docx') and not f.startswith('~$')
    )
    print(f"Found {len(all_files)} file .docx in {data_dir}")

    for file_name in all_files:
        file_path = os.path.join(data_dir, file_name)
        raw_text = read_docx(file_path)

        # Section of the law
        result = split_articles(raw_text)
        print(f"{file_name}: Split {len(result)} law")

        # Save JSON to processed folder; splitext touches only the trailing extension,
        # whereas str.replace would rewrite every '.docx' occurring in the name.
        json_file = os.path.splitext(file_name)[0] + '.json'
        json_path = os.path.join(output_dir, json_file)
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)

        print(f"Processed and saved: {json_path}")

# Run the whole preprocessing step: every .docx in data_dir becomes a JSON file in output_dir.
process_files()

Found 5 file .docx in TVPL_AI/data
100_2015_QH13_296661.docx: Split 415 law
Processed and saved: data/processed/100_2015_QH13_296661.json
45_2019_QH14_333670.docx: Split 214 law
Processed and saved: data/processed/45_2019_QH14_333670.json
168_2024_ND-CP_619502.docx: Split 54 law
Processed and saved: data/processed/168_2024_ND-CP_619502.json
41_2024_QH15_557190.docx: Split 137 law
Processed and saved: data/processed/41_2024_QH15_557190.json
145_2020_ND-CP_459400.docx: Split 0 law
Processed and saved: data/processed/145_2020_ND-CP_459400.json


## 4. Merge the processed JSON files into a single metadata file

In [28]:
import os
import numpy as np
import json

# Folder holding the per-document JSON files produced by the preprocessing step.
# It gets its own name so the global `data_dir` (the raw .docx folder read by
# process_files) is not overwritten, which would break re-running that step.
processed_dir = "data/processed"

# Sorted so the article order is the same on every run. The indexing cell uses the
# position in this list as the Qdrant point id.
all_files = sorted(f for f in os.listdir(processed_dir) if f.endswith('.json'))

texts = []
metadata = []

for file_name in all_files:
    file_path = os.path.join(processed_dir, file_name)
    with open(file_path, encoding="utf-8") as f:
        laws = json.load(f)

    # Document name without the trailing '.json'.
    doc_name = os.path.splitext(file_name)[0]

    for key, val in laws.items():
        full_text = val["title"] + ". " + val["text"]
        texts.append(full_text)

        # Add file name to ID to avoid duplication (article numbers repeat across documents)
        unique_id = f"{doc_name}_{key}"

        metadata.append({
            "id": unique_id,
            "file": file_name,
            "title": val["title"],
            "text": val["text"]
        })

print(f"Total Number of Articles: {len(texts)}")

# Save metadata
with open("laws_metadata.json", "w", encoding="utf-8") as f:
    json.dump(metadata, f, ensure_ascii=False, indent=2)

print("Full metadata saved, ready to be encoded.")

Total Number of Articles: 820
Full metadata saved, ready to be encoded.


## 5. Encode the articles and save them to Qdrant Cloud

In [29]:
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import PointStruct, Distance, VectorParams
import json

# 1. Cloud connectivity
# The API key is read from the environment (or asked for interactively) instead of
# being stored in the notebook. SECURITY: the key that used to be hard-coded here
# has been exposed and must be revoked and regenerated.
from getpass import getpass

QDRANT_URL = os.environ.get(
    "QDRANT_URL",
    "https://f8972297-f3c1-4741-ac07-79b7dc248785.europe-west3-0.gcp.cloud.qdrant.io:6333",
)
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") or getpass("Qdrant API key: ")
qdrant_client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
)

COLLECTION_NAME = "laws_collection"
UPSERT_BATCH_SIZE = 128  # points per request, keeps each upload small

# 2. Load JSON
with open("laws_metadata.json", encoding="utf-8") as f:
    laws = json.load(f)

if not laws:
    raise ValueError("laws_metadata.json contains no articles; run the preprocessing cells first.")

# 3. Encode the Law
# No explicit device: SentenceTransformer selects a GPU when one is available and
# otherwise runs on CPU, so the cell no longer crashes on CPU-only runtimes.
model = SentenceTransformer("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")

texts = [law["title"] + ". " + law["text"] for law in laws]
embeddings = model.encode(
    texts,
    batch_size=32,
    convert_to_numpy=True,
    normalize_embeddings=True
)

# 4. Recreate the collection
# Any existing collection is DROPPED and rebuilt from scratch. This happens only
# after encoding succeeded, so a failed run does not leave the index empty.
if qdrant_client.collection_exists(COLLECTION_NAME):
    qdrant_client.delete_collection(COLLECTION_NAME)

qdrant_client.create_collection(
    collection_name=COLLECTION_NAME,
    # Vector size is taken from the embeddings rather than hard-coded.
    vectors_config=VectorParams(size=int(embeddings.shape[1]), distance=Distance.COSINE)
)

# 5. Upsert into Qdrant
# The point id is the article's position in laws_metadata.json.
points = [
    PointStruct(
        id=idx,
        vector=vec.tolist(),
        payload={
            "id": law["id"],
            "file": law.get("file", ""),
            "title": law["title"],
            "text": law["text"]
        }
    )
    for idx, (vec, law) in enumerate(zip(embeddings, laws))
]

# Upload in batches so a large corpus does not go out as one oversized request.
for batch_start in range(0, len(points), UPSERT_BATCH_SIZE):
    qdrant_client.upsert(
        collection_name=COLLECTION_NAME,
        points=points[batch_start:batch_start + UPSERT_BATCH_SIZE]
    )

print("Upsert the Rule into Qdrant.")

Batches:   0%|          | 0/26 [00:00<?, ?it/s]

Upsert the Rule into Qdrant.


## 6. Connect to Qdrant Cloud and look up the answers

In [30]:
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient

# Encode the question
user_query = "NLĐ bị sa thải có được trả lương hay không?"
# No explicit device: SentenceTransformer selects a GPU when one is available and
# otherwise runs on CPU, so the cell also works on CPU-only runtimes.
model = SentenceTransformer("VoVanPhuc/sup-SimCSE-VietNamese-phobert-base")
query_vec = model.encode(user_query, normalize_embeddings=True)

# Connect to Qdrant Cloud (QDRANT_URL and QDRANT_API_KEY come from the indexing cell)
qdrant_client = QdrantClient(
    url=QDRANT_URL,
    api_key=QDRANT_API_KEY,
)
# Fetch more hits than needed (20) so that 5 remain after de-duplication.
results = qdrant_client.query_points(
    collection_name="laws_collection",
    query=query_vec.tolist(),
    limit=20,
    with_payload=True
).points

# Keep the 5 best hits with distinct titles. Articles that share a title, even
# across different documents, collapse into the highest-scoring one.
seen_titles = set()
unique_hits = []

for hit in results:
    title = hit.payload['title'].strip()

    if title in seen_titles:
        continue

    seen_titles.add(title)
    unique_hits.append(hit)

    if len(unique_hits) >= 5:
        break

for idx, hit in enumerate(unique_hits, 1):
    print(f"Top {idx}:")
    print(f"ID: {hit.payload['id']}")
    print(f"Title: {hit.payload['title']}")
    print(f"Score: {hit.score:.4f}")
    print("---")

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Top 1:
ID: 45_2019_QH14_333670_dieu_207
Title: Tiền lương và các quyền lợi hợp pháp khác của người lao động trong thời gian đình công
Score: 0.6874
---
Top 2:
ID: 45_2019_QH14_333670_dieu_90
Title: Tiền lương
Score: 0.6714
---
Top 3:
ID: 45_2019_QH14_333670_dieu_40
Title: Nghĩa vụ của người lao động khi đơn phương chấm dứt hợp đồng lao động trái pháp luật
Score: 0.6681
---
Top 4:
ID: 45_2019_QH14_333670_dieu_99
Title: Tiền lương ngừng việc
Score: 0.6635
---
Top 5:
ID: 41_2024_QH15_557190_dieu_31
Title: Căn cứ đóng bảo hiểm xã hội
Score: 0.6607
---


## 7. Search with a large number of queries

In [31]:
# Questions to run against the index
queries = [
    "NLĐ bị sa thải có được trả lương hay không?",
    "Người sử dụng lao động được sa thải người lao động nữ đang mang thai không?",
    "Quy định về điều chuyển nhân sự được quy định như thế nào?",
    "Người lao động được thuê làm giám đốc doanh nghiệp Nhà nước được hưởng các chế độ về tiền lương, thưởng như thế nào?",
    "Làm việc 8h một ngày thì được nghỉ giữa giờ ít nhất bao nhiêu phút?",
    "Người sử dụng lao động đào tạo nghề nghiệp và phát triển kỹ năng nghề cho người lao động như thế nào?",
    "Nguyên tắc cho thuê lại lao động là gì?",
    "Thời hạn của thỏa ước lao động tập thể như thế nào?",
    "Hợp đồng lao động được giao kết theo hình thức nào?",
    "Nội dung về đào tạo lao động có bắt buộc phải ghi vào hợp đồng lao động?"
]

with open("search_results.txt", "w", encoding="utf-8") as out_file:
    for query_no, user_query in enumerate(queries, start=1):
        # Embed the question and fetch the 20 nearest articles
        query_vec = model.encode(user_query, normalize_embeddings=True)
        hits = qdrant_client.query_points(
            collection_name="laws_collection",
            query=query_vec.tolist(),
            limit=20,
            with_payload=True
        ).points

        out_file.write(f"Query {query_no}: {user_query}\n")

        # Titles already written for this query, in rank order
        written_titles = []
        for hit in hits:
            title = hit.payload['title'].strip()
            if title not in written_titles:
                written_titles.append(title)
                out_file.writelines([
                    f"Article {len(written_titles)}: {title}\n",
                    "---\n",
                ])
                # Five distinct titles per query are enough
                if len(written_titles) == 5:
                    break

        out_file.write("\n\n")

print("Results saved to file 'search_results.txt'.")

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Results saved to file 'search_results.txt'.


## 8. Measure accuracy on a test set (10 pairs of `question → correct article`)

In [None]:
# Evaluation set: queries[i] is a question and ground_truth_ids[i] is the payload "id"
# of the article that answers it (e.g. "45_2019_QH14_333670_dieu_90").
queries = []
ground_truth_ids = []

TOP_K = 5  # a query counts as correct when its article is among the first TOP_K hits

if len(queries) != len(ground_truth_ids):
    raise ValueError(
        f"queries ({len(queries)}) and ground_truth_ids ({len(ground_truth_ids)}) must have the same length"
    )

correct = 0

for query, expected_id in zip(queries, ground_truth_ids):
    query_vec = model.encode(query, normalize_embeddings=True)
    hits = qdrant_client.query_points(
        collection_name="laws_collection",
        query=query_vec.tolist(),
        limit=TOP_K,
        with_payload=True
    ).points

    top_ids = [hit.payload['id'] for hit in hits]
    if expected_id in top_ids:
        correct += 1

if queries:
    accuracy = correct / len(queries)
    print(f"Accuracy@{TOP_K}: {accuracy:.2%}")
else:
    # With the placeholder lists still empty there is nothing to score (and nothing to divide by).
    accuracy = None
    print("No evaluation queries provided; fill in `queries` and `ground_truth_ids` to compute accuracy.")