In [None]:
!pip install sentence-transformers faiss-cpu transformers datasets tqdm

Collecting faiss-cpu
  Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (31.4 MB)
[2K   [90mโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ[0m [32m31.4/31.4 MB[0m [31m82.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.12.0


In [1]:
import json
import os
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def extract_articles_v2(dataset):
    """Flatten the nested laws dump into one record per article.

    ``dataset`` is walked as ``{category: {sub_category: {system_name: system_data}}}``.
    Every article listed under ``system_data["parts"][part_name]`` becomes a dict holding
    its own id/title/status/text plus the enclosing category, sub-category, system name,
    part name, and the system's ``brief`` and ``metadata`` (shared by reference across
    that system's articles).

    Returns:
        list[dict]: records in traversal order. Article fields missing from the source are
        ``None``; a missing ``brief`` becomes ``""`` and missing ``metadata`` becomes ``{}``.
    """
    records = []
    for category, subcats in dataset.items():
        for subcat, systems in subcats.items():
            for system_name, system_data in systems.items():
                brief = system_data.get("brief", "")
                metadata = system_data.get("metadata", {})
                records.extend(
                    {
                        "category": category,
                        "sub_category": subcat,
                        "system": system_name,
                        "part": part_name,
                        "brief": brief,
                        "metadata": metadata,
                        "id": article.get("id"),
                        "title": article.get("Article_Title"),
                        "status": article.get("status"),
                        "text": article.get("Article_Text"),
                    }
                    for part_name, part_articles in system_data.get("parts", {}).items()
                    for article in part_articles
                )
    return records


In [3]:
# Load the scraped laws dump and flatten it into one record per article.
with open("data/saudi_laws_scraped.json", encoding="utf-8") as f:
    data = json.load(f)

articles = extract_articles_v2(data)
print(f"โ Total Articles Extracted: {len(articles)}")
# Pretty-print one sample record (index 604 chosen by the author) with Arabic left unescaped.
sample_article = articles[604]
print(json.dumps(sample_article, ensure_ascii=False, indent=2))

โ Total Articles Extracted: 16371
{
  "category": "ุฃูุธูุฉ ุนุงุฏูุฉ",
  "sub_category": "ุงูุฃูู ุงูุฏุงุฎูู ูุงูุฃุญูุงู ุงููุฏููุฉ ูุงูุฃูุธูุฉ ุงูุฌูุงุฆูุฉ",
  "system": "ูุธุงู ููุงูุญุฉ ุบุณู ุงูุฃููุงู",
  "part": "main",
  "brief": "ูุชุถูู ุงููุธุงู:\r\nุงูููุตูุฏ ุจุงูุนุจุงุฑุงุช ูุงูุฃููุงุธ ุงููุงุฑุฏุฉ ุจุงููุธุงู. ุงูุฃูุนุงู ุงูุชู ูุนุฏ ูุฑุชูุจูุง ูุฑุชูุจูุง ุฌุฑููุฉ ุบุณู ุงูุฃููุงู โ ูุง ูุฌุจ ุนูู ุงููุคุณุณุงุช ุงููุงููุฉ ูุบูุฑ ุงููุงููุฉ ุงุชุฎุงุฐู ูู ุฅุฌุฑุงุกุงุช ุญูุงู ูุฑุชูุจ ุฌุฑููุฉ ุบุณู ุงูุฃููุงู โ ุงูุจุฑุงูุฌ ุงูุชู ุชุถุนูุง ุงููุคุณุณุงุช ุงููุงููุฉ ูุบูุฑ ุงููุงููุฉ ูููุงูุญุฉ ุนูููุงุช ุบุณู ุงูุฃููุงู โ ูุญุฏุฉ ููุงูุญุฉ ุบุณู ุงูุฃููุงู โ ุนููุจุฉ ูุฑุชูุจ ุฌุฑููุฉ ุบุณู ุงูุฃููุงู.",
  "metadata": {
    "ุงูุงุณู": "ูุธุงู ููุงูุญุฉ ุบุณู ุงูุฃููุงู",
    "ุชุงุฑูุฎ 

In [4]:
# Quick check: texts of all articles whose id equals 2.
[art["text"] for art in articles if art["id"] == 2]

['ูููู\nุนูู ุงูุฏููุฉ\nููุง ููู :\nุฃ  - ูููู ุฃุฎุถุฑ.\nุจ - ุนุฑุถู ูุณุงูู ุซูุซู ุทููู.\nุฌ - ุชุชูุณุทู ูููุฉ : (ูุง ุฅูู ุฅูุง ุงููู ูุญูุฏ ุฑุณูู ุงููู) ุชุญุชูุง ุณูู ูุณูููุ ููุง ูููุณ ุงูุนูู ุฃุจุฏุง.\nููุจูู  ุงููุธุงู  ุงูุฃุญูุงู ุงููุชุนููุฉ ุจู.']

In [5]:
def build_corpus(articles):
    """Render each flattened article record as one retrieval document.

    Each document is a ``"\\n\\n"``-separated sequence of labelled sections in the order
    title, brief, text, metadata; sections whose value is empty are omitted.

    Args:
        articles: iterable of dicts as produced by ``extract_articles_v2``; the
            ``title``, ``brief``, ``text`` and ``metadata`` keys are read.

    Returns:
        list[str]: one document per input record, in input order.
    """

    def _clean(value):
        # Records carry these keys even when the source field was missing (value None),
        # so ``art.get(key, "")`` alone would still hand ``None`` to ``.strip()``.
        return "" if value is None else str(value).strip()

    corpus = []
    for art in articles:
        title = _clean(art.get("title"))
        brief = _clean(art.get("brief"))
        text = _clean(art.get("text"))

        # Format metadata as "key: value" pairs, skipping empty values.
        meta = art.get("metadata") or {}
        meta_str = " ".join(f"{k}: {v}" for k, v in meta.items() if v)

        parts = [
            f"Law Title: {title}" if title else "",
            f"Law Brief: {brief}" if brief else "",
            f"Law Text: {text}" if text else "",
            f"Law Metadata: {meta_str}" if meta_str else "",
        ]

        # Drop empty sections and separate the rest with blank lines.
        entry = "\n\n".join(filter(None, parts)).strip()
        corpus.append(entry)

    return corpus


corpus = build_corpus(articles)
print(f"โ Corpus built with {len(corpus)} documents")
# Preview the first two documents, truncated to 400 characters. Slicing (instead of
# indexing corpus[0] and corpus[1]) avoids an IndexError on very small corpora.
for n, doc in enumerate(corpus[:2], start=1):
    print(f"\n--- Example {n} ---\n{doc[:400]}...")

โ Corpus built with 16371 documents

--- Example 1 ---
Law Title: ุงููุงุฏุฉ ุงูุฃููู

Law Brief: ูุชุถูู ุงูุนูุงููู ุงูุชุงููุฉ: ุงููุจุงุฏุฆ ุงูุนุงูุฉุ ูุธุงู ุงูุญููุ ููููุงุช ุงููุฌุชูุน ุงูุณุนูุฏูุ ุงููุจุงุฏุฆ ุงูุงูุชุตุงุฏูุฉุ ุงูุญููู ูุงููุงุฌุจุงุชุ ุณูุทุงุช ุงูุฏููุฉุ ุงูุดุฆูู ุงููุงููุฉุ ุฃุญูุงู ุนุงูุฉ.

Law Text: ุงูููููุฉ ุงูุนุฑุจูุฉ ุงูุณุนูุฏูุฉุ ุฏููุฉ ุนุฑุจูุฉ ุฅุณูุงููุฉุ ุฐุงุช
ุณูุงุฏุฉ ุชุงูุฉ
ุ ุฏูููุง
ุงูุฅุณูุงู
ุ ูุฏุณุชูุฑูุง
ูุชุงุจ ุงููู ุชุนุงูู
ูุณูุฉ ุฑุณููู ุตูู ุงููู ุนููู ูุณูู. ููุบุชูุง ูู ุงููุบุฉ ุงูุนุฑุจูุฉุ ูุนุงุตูุชูุง ูุฏููุฉ ุงูุฑูุงุถ.

Law Metadata...

--- Example 2 ---
Law Title: ุงููุงุฏุฉ ุงูุซุงููุฉ

Law Brief: ูุชุถูู ุงูุนูุงููู ุงูุชุงููุฉ: ุงููุจุงุฏุฆ ุงูุนุงูุฉุ ูุธุงู ุงูุญููุ ููููุงุช ุงููุฌุชูุน ุงูุณุนูุฏูุ ุงููุจุงุฏุฆ ุงูุงูุชุตุงุฏ

In [7]:
import torch

# Run bge-m3 on the GPU when torch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
embed_model = SentenceTransformer("BAAI/bge-m3", device=DEVICE)

In [18]:
def embed_corpus(corpus, embed_model, batch_size=128):
    """Encode ``corpus`` in batches into a float32 matrix of L2-normalised embeddings.

    Args:
        corpus: sequence of strings (sliced per batch).
        embed_model: a SentenceTransformer-like model exposing
            ``get_sentence_embedding_dimension()`` and ``encode(...)``.
        batch_size: number of documents passed to each ``encode`` call.

    Returns:
        np.ndarray of shape ``(len(corpus), dim)`` and dtype float32; row ``k`` is the
        embedding of ``corpus[k]``.
    """
    total = len(corpus)
    out = np.zeros(
        (total, embed_model.get_sentence_embedding_dimension()), dtype=np.float32
    )

    # Fill the preallocated matrix one batch at a time; tqdm shows batch progress.
    for lo in tqdm(range(0, total, batch_size)):
        hi = min(lo + batch_size, total)
        out[lo:hi] = embed_model.encode(
            corpus[lo:hi],
            show_progress_bar=False,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )

    return out

def save_faiss_index(embeddings, index_path="m3_legal_faiss.index"):
    """Index ``embeddings`` for exact inner-product search and write the index to disk.

    A flat ``IndexFlatIP`` is wrapped in ``IndexIDMap``, and row ``k`` is stored under
    id ``k``, so search hits map straight back to corpus positions. When the rows are
    L2-normalised (as ``embed_corpus`` produces), inner product equals cosine similarity.

    Args:
        embeddings: 2-D array with one vector per row.
        index_path: destination file passed to ``faiss.write_index``.

    Returns:
        The populated ``IndexIDMap``, still usable in memory.
    """
    flat_ip = faiss.IndexFlatIP(embeddings.shape[1])
    id_mapped = faiss.IndexIDMap(flat_ip)
    row_ids = np.arange(embeddings.shape[0])
    id_mapped.add_with_ids(embeddings, row_ids)
    faiss.write_index(id_mapped, index_path)
    return id_mapped


In [None]:
# Embed the article-level corpus a single time (encoding 16k documents is the expensive
# step) and persist an exact inner-product index over the result.
BRIEF_INDEX_PATH = "legal_faiss_brief.index"
embeddings = embed_corpus(corpus, embed_model)
print(f"โ Embeddings shape: {embeddings.shape}")
index = save_faiss_index(embeddings, index_path=BRIEF_INDEX_PATH)
print(f"โ FAISS index saved as '{BRIEF_INDEX_PATH}'")

In [None]:
# Hand-picked subset of row ids that retrieval is restricted to by default
# (kept as module-level names so other cells can reuse them).
filtered_indices = np.array([0, 1, 2, 5], dtype=np.int64)
selector = faiss.IDSelectorArray(filtered_indices)


def retrieve(query, top_k=5, allowed_ids=filtered_indices):
    """Return up to ``top_k`` nearest rows of the global ``index`` for ``query``.

    The query is embedded with the global ``embed_model`` using the same settings as
    ``embed_corpus`` (normalised numpy output).

    Args:
        query: query string.
        top_k: maximum number of hits to return.
        allowed_ids: row ids the search is restricted to; defaults to
            ``filtered_indices``. Pass ``None`` to search the whole index.

    Returns:
        list[tuple[int, float]]: ``(row_id, score)`` pairs in the order FAISS returns
        them. May hold fewer than ``top_k`` entries: FAISS pads missing hits with id
        ``-1`` (e.g. only 4 ids are allowed by default but ``top_k`` is 5), and those
        slots are dropped so callers never index ``corpus[-1]`` by accident.
    """
    q_emb = embed_model.encode([query], convert_to_numpy=True, normalize_embeddings=True)

    params = None
    if allowed_ids is not None:
        # Bind the id array and selector to locals so they stay alive during the search.
        ids = np.ascontiguousarray(allowed_ids, dtype=np.int64)
        sel = faiss.IDSelectorArray(ids)
        params = faiss.SearchParameters(sel=sel)

    scores, hits = index.search(q_emb, top_k, params=params)
    return [
        (int(doc_id), float(score))
        for doc_id, score in zip(hits[0], scores[0])
        if doc_id >= 0
    ]
query = "ูุง ูู ุดุนุงุฑ ุงูุฏููุฉ ุงูุณุนูุฏูุฉุ"
results = retrieve(query)
# Show each hit's row id, score and a 400-character preview of its document.
for idx, score in results:
    preview = corpus[idx][:400]
    print(f"\n๐น Result ID {idx} (score={score:.3f})\n{preview}...")


๐น Result ID 0 (score=0.578)
ุงููุงุฏุฉ ุงูุฃููู - ุงูููููุฉ ุงูุนุฑุจูุฉ ุงูุณุนูุฏูุฉุ ุฏููุฉ ุนุฑุจูุฉ ุฅุณูุงููุฉุ ุฐุงุช
ุณูุงุฏุฉ ุชุงูุฉ
ุ ุฏูููุง
ุงูุฅุณูุงู
ุ ูุฏุณุชูุฑูุง
ูุชุงุจ ุงููู ุชุนุงูู
ูุณูุฉ ุฑุณููู ุตูู ุงููู ุนููู ูุณูู. ููุบุชูุง ูู ุงููุบุฉ ุงูุนุฑุจูุฉุ ูุนุงุตูุชูุง ูุฏููุฉ ุงูุฑูุงุถ. ุงูุงุณู: ุงููุธุงู ุงูุฃุณุงุณู ููุญูู ุชุงุฑูุฎ ุงูุฅุตุฏุงุฑ: 1412/08/27 ูู  ุงูููุงูู : 01/03/1992 ูู ุชุงุฑูุฎ ุงููุดุฑ: 1412/09/02  ูู ุงูููุงูู : 06/03/1992 ูู ุงูุญุงูุฉ: ุณุงุฑู ุฃุฏูุงุช ุฅุตุฏุงุฑ ุงููุธุงู: [{'text': 'ุฃูุฑ ูููู ุฑูู ุฃ/90 ุจุชุงุฑู...

๐น Result ID 2 (score=0.562)
ุงููุงุฏุฉ ุงูุซุงูุซุฉ - ูููู
ุนูู ุงูุฏููุฉ
ููุง ููู :
ุฃ  - ูููู ุฃุฎุถุฑ.
ุจ - ุนุฑุถู ูุณุงูู ุซูุซู ุทููู.
ุฌ - ุชุชูุณุทู ูููุฉ : (ูุง ุฅูู ุฅูุง ุงููู ูุญูุฏ ุฑุณูู ุงููู) ุชุญุชูุง ุณูู ูุณูููุ

In [24]:
def build_law_corpus(data):
    """Build one short document per law from its name and summary ("brief").

    Args:
        data: nested mapping ``{main_category: {sub_category: {law_title: law_data}}}``
            where each ``law_data`` dict may carry a ``"brief"`` string.

    Returns:
        list[str]: ``"Law Name:<title>"`` and ``"Law Summary:<brief>"`` joined by a
        newline, one entry per law in traversal order. A missing or null brief yields
        an empty summary instead of raising.
    """
    # Only the law title and its data are needed; category names are not embedded.
    return [
        "Law Name:" + law_title + "\nLaw Summary:" + (law_data.get("brief") or "")
        for sub_categories in data.values()
        for laws in sub_categories.values()
        for law_title, law_data in laws.items()
    ]

laws_corpus = build_law_corpus(data)
n_laws = len(laws_corpus)
print(f"โ Laws corpus built with {n_laws} entries")
# Display the first law-level document.
laws_corpus[0]

โ Laws corpus built with 517 entries


'Law Name:ุงููุธุงู ุงูุฃุณุงุณู ููุญูู\nLaw Summary:ูุชุถูู ุงูุนูุงููู ุงูุชุงููุฉ: ุงููุจุงุฏุฆ ุงูุนุงูุฉุ ูุธุงู ุงูุญููุ ููููุงุช ุงููุฌุชูุน ุงูุณุนูุฏูุ ุงููุจุงุฏุฆ ุงูุงูุชุตุงุฏูุฉุ ุงูุญููู ูุงููุงุฌุจุงุชุ ุณูุทุงุช ุงูุฏููุฉุ ุงูุดุฆูู ุงููุงููุฉุ ุฃุญูุงู ุนุงูุฉ.'

In [25]:
# Embed the law-level (name + summary) documents with the same bge-m3 model.
embeddings = embed_corpus(laws_corpus, embed_model, batch_size=128)
print(f"โ Laws corpus embeddings shape: {embeddings.shape}")

100%|โโโโโโโโโโ| 5/5 [00:03<00:00,  1.54it/s]

โ Laws corpus embeddings shape: (517, 1024)





In [None]:
# Persist an exact inner-product index over the law-level embeddings.
law_index = save_faiss_index(embeddings, "laws_legal_faiss.index")

In [15]:
def build_parts_corpus(data, max_tokens=512, chunk_overlap=50, tokenizer_func=None):
    """Build one retrieval document per (law, part), chunking parts that are too long.

    Each document starts with the law title and brief, then (for parts other than
    ``"main"``) the part name, then one entry per article of the part, all joined with
    newlines. Documents whose token count exceeds ``max_tokens`` are split into several
    corpus entries that share the same part id.

    Args:
        data: nested mapping ``{main_category: {sub_category: {law_title: law_data}}}``;
            ``law_data`` may hold ``"brief"`` (str) and ``"parts"``
            (``{part_name: [article, ...]}``); each article is a dict with optional
            ``"Article_Title"`` / ``"Article_Text"``.
        max_tokens: token budget per corpus entry.
        chunk_overlap: token budget for the trailing whole sentences that a chunk
            repeats from the previous chunk of the same part (``<= 0`` disables overlap).
        tokenizer_func: callable ``f(text, add_special_tokens=False) -> sequence``
            (e.g. a Hugging Face ``tokenizer.encode``). When ``None``, tokens are
            approximated by whitespace-separated words.

    Returns:
        dict with ``"corpus"`` (list[str]) and ``"corpus_id_to_part_id"`` (dict mapping
        each corpus index to ``"<law_title>|<part_name>"``).
    """

    def count_tokens(text):
        """Token count of ``text`` under ``tokenizer_func``, or its word count as a fallback."""
        if tokenizer_func is None:
            return len(text.split())
        return len(tokenizer_func(text, add_special_tokens=False))

    def chunk_text(text, max_tokens, overlap):
        """Greedily pack ``'. '``-separated sentences into chunks of about ``max_tokens`` tokens.

        Token counts are summed per sentence (separators are not counted), so limits are
        approximate, and a single sentence longer than ``max_tokens`` becomes its own
        oversized chunk. After each split, the next chunk starts with the longest run of
        trailing sentences from the previous chunk that fits in ``overlap`` tokens and
        still leaves room for the incoming sentence.
        """
        sentences = text.split(". ")
        chunks = []
        current = []         # sentences of the chunk being built
        current_counts = []  # token count of each sentence in `current`
        current_total = 0

        for sentence in sentences:
            n_tokens = count_tokens(sentence)

            if current and current_total + n_tokens > max_tokens:
                chunks.append(". ".join(current).strip())

                # Choose how many trailing sentences to carry over as overlap. Because
                # the chunk just overflowed, this can never be the whole chunk, so each
                # new chunk always contains at least one new sentence.
                keep = 0
                kept_tokens = 0
                if overlap > 0:
                    for count in reversed(current_counts):
                        if (kept_tokens + count > overlap
                                or kept_tokens + count + n_tokens > max_tokens):
                            break
                        kept_tokens += count
                        keep += 1
                # Slice by position: current[-0:] would keep everything.
                current = current[len(current) - keep:]
                current_counts = current_counts[len(current_counts) - keep:]
                current_total = kept_tokens

            current.append(sentence)
            current_counts.append(n_tokens)
            current_total += n_tokens

        if current:
            chunks.append(". ".join(current).strip())
        return chunks

    def format_article(article):
        """Render one article as 'Article Title: ...' / 'Article Text: ...' lines, skipping empty fields."""
        # Fields may be present with a None value, so fall back to "" before stripping.
        title = (article.get("Article_Title") or "").strip()
        text = (article.get("Article_Text") or "").strip()
        lines = []
        if title:
            lines.append(f"Article Title: {title}")
        if text:
            lines.append(f"Article Text: {text}")
        return "\n".join(lines)

    corpus = []         # document strings; a document's position is its corpus id
    id_to_part_id = {}  # corpus id -> "<law_title>|<part_name>"

    # Traverse the 3-level hierarchy; category names themselves are not embedded.
    for sub_categories in data.values():
        for laws in sub_categories.values():
            for law_title, law_data in laws.items():
                brief = law_data.get("brief") or ""
                law_text = "Law Title: " + law_title + "\n" + "Law Brief: " + brief

                for part_name, part_articles in (law_data.get("parts") or {}).items():
                    context_lines = [law_text]
                    # The implicit "main" part has no heading of its own.
                    if part_name != "main":
                        context_lines.append(f"Part Name: {part_name}")
                    context_lines.extend(format_article(article) for article in part_articles)
                    full_text = "\n".join(context_lines)

                    part_id = law_title + "|" + part_name
                    if count_tokens(full_text) > max_tokens:
                        pieces = chunk_text(full_text, max_tokens, chunk_overlap)
                    else:
                        pieces = [full_text]

                    for piece in pieces:
                        id_to_part_id[len(corpus)] = part_id
                        corpus.append(piece)

    return {
        "corpus": corpus,
        "corpus_id_to_part_id": id_to_part_id,
    }

In [34]:
# Longest input (in tokens) the bge-m3 tokenizer is configured for.
bge_tokenizer = embed_model.tokenizer
bge_tokenizer.model_max_length

8192

In [16]:
# Take the chunking budget from the tokenizer itself (8192 in the logged run) instead
# of repeating the number by hand, and count tokens with the same tokenizer.
parts_data = build_parts_corpus(
    data,
    max_tokens=embed_model.tokenizer.model_max_length,
    chunk_overlap=50,
    tokenizer_func=embed_model.tokenizer.encode,
)
corpus = parts_data["corpus"]
print(f"โ Parts corpus built with {len(corpus)} entries")


โ Parts corpus built with 2050 entries


In [17]:
# Sanity check: every corpus entry maps back to exactly one "<law>|<part>" id.
assert len(parts_data["corpus_id_to_part_id"]) == len(parts_data["corpus"])
len(parts_data["corpus_id_to_part_id"])

2050

In [40]:
# Part documents can be far longer than single articles; presumably that is why the
# batch size is reduced from the default 128 to 16 here.
embeddings = embed_corpus(corpus, embed_model, batch_size=16)
print(f"โ Parts corpus embeddings shape: {embeddings.shape}")
index = save_faiss_index(embeddings, "m3_legal_faiss_parts.index")

100%|โโโโโโโโโโ| 129/129 [12:01<00:00,  5.60s/it]

โ Parts corpus embeddings shape: (2050, 1024)





In [19]:
# Persist the corpus-id -> part-id mapping next to the index. The `with` block closes
# (and flushes) the file deterministically. JSON object keys are strings, so the int
# corpus ids come back as "0", "1", ... when loaded; convert with int(k).
with open("m3_corpus_id_to_part_id_parts.json", "w", encoding="utf-8") as fp:
    json.dump(parts_data["corpus_id_to_part_id"], fp, ensure_ascii=False, indent=2)