# Amharic Legal RAG System

This notebook builds a retrieval-augmented system for Amharic legal documents. 

**Purpose**: Provide contextual grounding for legal text simplification, NOT generation.

**Key Features**:
- Article-level chunking (አንቀጽ)
- Gemini embeddings for Amharic
- Rich legal metadata
- Retrieval-only (no generation)


# 1. Environment Setup


In [29]:
import os
import sys
from pathlib import Path

# Set working directory to project root
WORKDIR = Path("/Users/blank/Documents/Foundation Models Course Projects")
os.chdir(WORKDIR)

print(f"Working directory: {WORKDIR}")
print(f"Directory exists: {WORKDIR.exists()}")


Working directory: /Users/blank/Documents/Foundation Models Course Projects
Directory exists: True


# 2. Dependency Installation


In [30]:
%pip install -q pymupdf faiss-cpu pandas tqdm pyarrow google-generativeai numpy



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [None]:
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class LegalRagConfig:
    # Raw data paths
    raw_pdf_dir: Path = WORKDIR / "Dataset" / "all_pdfs"
    
    # Processing pipeline paths (all under rag_pipeline folder)
    extracted_dir: Path = WORKDIR / "rag_pipeline" / "1_extracted"
    normalized_dir: Path = WORKDIR / "rag_pipeline" / "2_normalized"
    chunked_dir: Path = WORKDIR / "rag_pipeline" / "3_chunked"
    
    # Vector database
    index_path: Path = WORKDIR / "rag_pipeline" / "4_vector_db" / "faiss_index.bin"
    metadata_path: Path = WORKDIR / "rag_pipeline" / "4_vector_db" / "metadata.parquet"
    
    # OCR normalization map path
    ocr_map_path: Path = WORKDIR / "Dataset" / "scripts" / "glossary" / "ocr_normalization_map.py"
    
    # Gemini API configuration
    gemini_api_key: str = ""  # Set here or use GEMINI_API_KEY environment variable
    gemini_embedding_model: str = "models/text-embedding-004"
    
    # Processing parameters
    batch_size_docs: Optional[int] = None  # None = process all PDFs, or set a number to limit
    embedding_batch_size: int = 100  # Chunks per Gemini embedding API call
    
    # Chunk size limits (to avoid API payload size errors)
    max_chunk_size_bytes: int = 30_000_000  # ~30MB max per chunk (Gemini limit is 40MB, but we leave buffer)
    max_chunk_chars: int = 10_000_000  # ~10M characters max per chunk
    
    # Retrieval parameters
    top_k_retrieval: int = 2  # Hard limit: 2 chunks for simplification (max 3 only if sentence has conditions+exceptions)

cfg = LegalRagConfig()

# Create all necessary directories
for path in [
    cfg.extracted_dir,
    cfg.normalized_dir,
    cfg.chunked_dir,
    cfg.index_path.parent,
]:
    path.mkdir(parents=True, exist_ok=True)

print(f"PDF directory: {cfg.raw_pdf_dir}")
print(f"PDF directory exists: {cfg.raw_pdf_dir.exists()}")

# Load Gemini API key from file or environment
def load_api_key_from_file(filepath: Path = None) -> str:
    """Load API key from .gemini_api_key file."""
    if filepath is None:
        possible_paths = [
            WORKDIR / ".gemini_api_key",
            WORKDIR / "gemini_api_key.txt",
            Path.home() / ".gemini_api_key",
        ]
        for path in possible_paths:
            if path.exists():
                try:
                    content = path.read_text().strip()
                    if "=" in content:
                        key = content.split("=", 1)[1].strip().strip('"').strip("'")
                    else:
                        key = content.strip()
                    return key
                except Exception as e:
                    print(f"[WARN] Error reading API key file {path}: {e}")
                    continue
    return ""

# Load API key
if not cfg.gemini_api_key:
    cfg.gemini_api_key = load_api_key_from_file()
    if cfg.gemini_api_key:
        print("Gemini API key loaded from file")
    else:
        cfg.gemini_api_key = os.getenv("GEMINI_API_KEY", "")
        if cfg.gemini_api_key:
            print("Gemini API key found in environment variable")
        else:
            print("WARNING: Gemini API key not found. Please create .gemini_api_key file or set GEMINI_API_KEY environment variable")

cfg


PDF directory: /Users/blank/Documents/Foundation Models Course Projects/Dataset/all_pdfs
PDF directory exists: True
Gemini API key loaded from file


LegalRagConfig(raw_pdf_dir=PosixPath('/Users/blank/Documents/Foundation Models Course Projects/Dataset/all_pdfs'), extracted_dir=PosixPath('/Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/1_extracted'), normalized_dir=PosixPath('/Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/2_normalized'), chunked_dir=PosixPath('/Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/3_chunked'), index_path=PosixPath('/Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/4_vector_db/faiss_index.bin'), metadata_path=PosixPath('/Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/4_vector_db/metadata.parquet'), ocr_map_path=PosixPath('/Users/blank/Documents/Foundation Models Course Projects/Dataset/scripts/glossary/ocr_normalization_map.py'), gemini_api_key='REPLACE WITH YOUR API KEY HERE', gemini_embedding_model='models/text-embedding-004', batch_size_docs=None, embedding_batch_size=100, max_chunk_size_bytes=300

In [32]:
import json
import re
import hashlib
from typing import Dict, List
import fitz  # pymupdf
import pandas as pd
from tqdm.auto import tqdm

# Load OCR normalization map
def load_ocr_map() -> Dict[str, str]:
    """Load OCR normalization map from the existing script."""
    try:
        import importlib.util
        spec = importlib.util.spec_from_file_location("ocr_map", cfg.ocr_map_path)
        ocr_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(ocr_module)
        return getattr(ocr_module, "OCR_MAP", {})
    except Exception as e:
        print(f"[WARN] Could not load OCR map: {e}")
        return {}

OCR_MAP = load_ocr_map()
print(f"Loaded {len(OCR_MAP)} OCR normalization rules")

def apply_ocr_normalization(text: str) -> str:
    """Apply OCR normalization to text."""
    normalized = text
    for src, dst in OCR_MAP.items():
        normalized = normalized.replace(src, dst)
    return normalized

def normalize_geez_numerals(text: str) -> str:
    """Convert Ge'ez numerals to Arabic numerals."""
    # Ge'ez numeral mapping (basic)
    geez_to_arabic = {
        "፩": "1", "፪": "2", "፫": "3", "፬": "4", "፭": "5",
        "፮": "6", "፯": "7", "፰": "8", "፱": "9", "፲": "10",
    }
    normalized = text
    for geez, arabic in geez_to_arabic.items():
        normalized = normalized.replace(geez, arabic)
    return normalized

def clean_legal_text(text: str) -> str:
    """Clean legal text: remove headers, footers, page numbers, English content."""
    # Remove common headers/footers
    drop_patterns = [
        r"ፌዴራል ነጋሪት ጋዜጣ",
        r"የኢትዮጵያ ፌዴራላዊ ዲሞክራሲያዊ ሪፐብሊክ",
        r"ሕዝብ ተወካዮች ምክር ቤት",
        r"አዋጅ ቁጥር",
        r"ገጽ\s*\d+",  # Page numbers
        r"ዓመት ቁጥር",
    ]
    
    cleaned = text
    for pattern in drop_patterns:
        cleaned = re.sub(pattern, "", cleaned)
    
    # Remove lines that are primarily English
    lines = cleaned.split("\n")
    amharic_lines = []
    ethiopic_re = re.compile(r"[\u1200-\u137F\u1380-\u139F\u2D80-\u2DDF\uAB00-\uAB2F]")
    
    for line in lines:
        line = line.strip()
        if not line:
            continue
        # Keep only lines with Ethiopic characters
        if ethiopic_re.search(line):
            amharic_lines.append(line)
    
    # Reconstruct text
    cleaned = "\n".join(amharic_lines)
    
    # Normalize whitespace
    cleaned = re.sub(r"\n{3,}", "\n\n", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned)
    
    return cleaned.strip()

def extract_pdf_text(pdf_path: Path) -> str:
    """Extract text from PDF using pymupdf."""
    try:
        doc = fitz.open(str(pdf_path))
        text_parts = []
        for page in doc:
            text = page.get_text("text")
            if text:
                text_parts.append(text)
        doc.close()
        return "\n".join(text_parts)
    except Exception as e:
        print(f"[ERROR] Failed to extract {pdf_path}: {e}")
        return ""

def get_file_hash(filepath: Path) -> str:
    """Compute MD5 hash of file content."""
    hash_md5 = hashlib.md5()
    try:
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()
    except Exception:
        return None

def extract_and_normalize_pdfs(max_docs: int = None) -> List[Dict]:
    """Extract and normalize PDFs. Returns list of processed documents."""
    # Find all PDFs
    pdfs = sorted(cfg.raw_pdf_dir.glob("*.pdf"))
    print(f"Found {len(pdfs)} PDF files")
    
    # Check already processed files
    processed_paths = set()
    extracted_files = sorted(cfg.extracted_dir.glob("*.json"))
    
    for json_file in extracted_files:
        try:
            payload = json.loads(json_file.read_text(encoding="utf-8"))
            source_path = payload.get("source_path")
            if source_path:
                processed_paths.add(str(Path(source_path).resolve()))
        except Exception:
            pass
    
    print(f"Already processed: {len(processed_paths)} files")
    
    # Filter pending files
    pending = []
    for pdf in pdfs:
        if str(pdf.resolve()) not in processed_paths:
            pending.append(pdf)
    
    print(f"Pending to process: {len(pending)} files")
    
    # Apply limit
    if max_docs is not None:
        pending = pending[:max_docs]
    
    results = []
    
    for pdf_path in tqdm(pending, desc="Extracting & Normalizing", unit="file"):
        try:
            # Extract raw text
            raw_text = extract_pdf_text(pdf_path)
            if not raw_text or len(raw_text.strip()) < 100:
                print(f"[WARN] {pdf_path.name} has insufficient text, skipping")
                continue
            
            # Clean text
            cleaned_text = clean_legal_text(raw_text)
            
            # Apply OCR normalization
            normalized_text = apply_ocr_normalization(cleaned_text)
            
            # Normalize numerals
            normalized_text = normalize_geez_numerals(normalized_text)
            
            # Save extracted text
            extracted_path = cfg.extracted_dir / f"{pdf_path.stem}.json"
            payload = {
                "source_path": str(pdf_path),
                "source_title": pdf_path.stem,
                "raw_text": raw_text,
                "cleaned_text": cleaned_text,
                "normalized_text": normalized_text,
                "file_hash": get_file_hash(pdf_path),
            }
            extracted_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
            
            # Save normalized text
            normalized_path = cfg.normalized_dir / f"{pdf_path.stem}.txt"
            normalized_path.write_text(normalized_text, encoding="utf-8")
            
            results.append({
                "source": str(pdf_path),
                "source_title": pdf_path.stem,
                "extracted": str(extracted_path),
                "normalized": str(normalized_path),
                "text_length": len(normalized_text),
            })
        except Exception as e:
            print(f"[ERROR] Failed to process {pdf_path}: {e}")
            continue
    
    print(f"\nExtracted & normalized {len(results)} documents")
    return results

# Run extraction
extraction_results = extract_and_normalize_pdfs(max_docs=cfg.batch_size_docs)


Loaded 93 OCR normalization rules
Found 94 PDF files
Already processed: 94 files
Pending to process: 0 files


Extracting & Normalizing: 0file [00:00, ?file/s]


Extracted & normalized 0 documents





In [33]:
import uuid
from datetime import datetime

def detect_article_structure(text: str) -> List[Dict]:
    """
    Detect articles (አንቀጽ) and sub-articles (ንዑስ አንቀጽ) in legal text.
    Returns list of article chunks with metadata.
    """
    articles = []
    
    # Enhanced patterns for article detection - multiple formats
    # Pattern 1: Standard "አንቀጽ" with number
    # Pattern 2: "አንቀጽ" with "ቁጥር" (number)
    # Pattern 3: Just number at start of line (common in some formats)
    # Pattern 4: Article with parentheses or brackets
    article_patterns = [
        re.compile(r"አንቀጽ\s*(?:ቁጥር\s*)?[:\s]*(\d+)", re.IGNORECASE),
        re.compile(r"^[\s]*(\d+)[\.\)]\s", re.MULTILINE),  # Number at start of line
        re.compile(r"አንቀጽ\s*[\(\[]\s*(\d+)\s*[\)\]]", re.IGNORECASE),
        re.compile(r"አንቀጽ\s*(\d+)", re.IGNORECASE),
    ]
    
    # Sub-article patterns
    sub_article_patterns = [
        re.compile(r"ንዑስ\s*አንቀጽ\s*(?:ቁጥር\s*)?[:\s]*(\d+)", re.IGNORECASE),
        re.compile(r"ንዑስ\s*አንቀጽ\s*(\d+)", re.IGNORECASE),
    ]
    
    # Split text into lines for processing
    lines = text.split("\n")
    
    current_article = None
    current_article_num = None
    current_article_text = []
    current_sub_articles = []
    
    for i, line in enumerate(lines):
        line = line.strip()
        if not line:
            continue
        
        # Check if this line starts a new article (try all patterns)
        article_match = None
        for pattern in article_patterns:
            article_match = pattern.search(line)
            if article_match:
                break
        
        if article_match:
            # Save previous article if exists
            if current_article_num is not None and current_article_text:
                article_text = "\n".join(current_article_text).strip()
                if article_text:
                    articles.append({
                        "article_number": current_article_num,
                        "text": article_text,
                        "sub_articles": current_sub_articles.copy(),
                        "start_line": i - len(current_article_text),
                        "end_line": i,
                    })
            
            # Start new article
            current_article_num = article_match.group(1)
            current_article_text = [line]
            current_sub_articles = []
            continue
        
        # Check if this line is a sub-article (try all patterns)
        sub_article_match = None
        for pattern in sub_article_patterns:
            sub_article_match = pattern.search(line)
            if sub_article_match:
                break
        
        if sub_article_match:
            sub_article_num = sub_article_match.group(1)
            if sub_article_num not in current_sub_articles:
                current_sub_articles.append(sub_article_num)
        
        # Add line to current article
        if current_article_num is not None:
            current_article_text.append(line)
    
    # Save last article
    if current_article_num is not None and current_article_text:
        article_text = "\n".join(current_article_text).strip()
        if article_text:
            articles.append({
                "article_number": current_article_num,
                "text": article_text,
                "sub_articles": current_sub_articles.copy(),
                "start_line": len(lines) - len(current_article_text),
                "end_line": len(lines),
            })
    
    return articles

def detect_legal_function(text: str) -> str:
    """
    Detect the legal function of a text chunk.
    Returns: 'obligation', 'prohibition', 'permission', 'condition', 'definition', or 'other'
    """
    text_lower = text.lower()
    
    # Obligation patterns (must, shall, required)
    obligation_patterns = [
        r'አለበት',
        r'ይገባል',
        r'ይገደዳል',
        r'አስፈላጊ',
        r'ግዴታ',
        r'መሆን አለበት',
        r'መከበር አለበት',
    ]
    
    # Prohibition patterns (may not, forbidden)
    prohibition_patterns = [
        r'አይፈቀድም',
        r'አይቻልም',
        r'አይፈቀድ',
        r'ክልከላ',
        r'ተከለከለ',
        r'አይፈቀድም',
    ]
    
    # Permission patterns (may, allowed)
    permission_patterns = [
        r'ይፈቀዳል',
        r'ይችላል',
        r'መሆን ይችላል',
        r'ፈቃድ',
    ]
    
    # Condition patterns (if, when, unless)
    condition_patterns = [
        r'ከ\.\.\.\s*በስተቀር',
        r'ካል\.\.\.\s*በስተቀር',
        r'ቢሆንም',
        r'እስከሚ\.\.\.\s*ድረስ',
        r'ከሆነ',
        r'በሆነ ሁኔታ',
        r'ካልሆነ',
    ]
    
    # Definition patterns (means, refers to, is defined as)
    definition_patterns = [
        r'ማለት',
        r'ነው',
        r'የሚለው',
        r'የሚመለከት',
        r'ትርጉም',
        r'ፍቺ',
    ]
    
    # Check patterns in order of specificity
    for pattern in obligation_patterns:
        if re.search(pattern, text):
            return 'obligation'
    
    for pattern in prohibition_patterns:
        if re.search(pattern, text):
            return 'prohibition'
    
    for pattern in condition_patterns:
        if re.search(pattern, text):
            return 'condition'
    
    for pattern in permission_patterns:
        if re.search(pattern, text):
            return 'permission'
    
    for pattern in definition_patterns:
        if re.search(pattern, text):
            return 'definition'
    
    return 'other'

def split_into_functional_units(text: str, article_num: str) -> List[Dict]:
    """
    Split article text into functional legal units (obligation, prohibition, etc.).
    Returns list of functional chunks with their detected function type.
    """
    functional_chunks = []
    
    # Split by sentence boundaries (Amharic sentence end marker: ።)
    sentences = re.split(r'([።]+)', text)
    
    # Reconstruct sentences (keep the delimiter with previous sentence)
    reconstructed_sentences = []
    for i in range(0, len(sentences), 2):
        if i < len(sentences):
            sentence = sentences[i]
            if i + 1 < len(sentences):
                sentence += sentences[i + 1]
            if sentence.strip():
                reconstructed_sentences.append(sentence.strip())
    
    # Group sentences by legal function
    current_function = None
    current_chunk = []
    
    for sentence in reconstructed_sentences:
        if not sentence or len(sentence) < 10:
            continue
        
        # Detect function of this sentence
        sentence_function = detect_legal_function(sentence)
        
        # If function changes or chunk is getting large, save current chunk
        if current_function and current_function != sentence_function and current_chunk:
            chunk_text = " ".join(current_chunk).strip()
            if len(chunk_text) > 20:  # Minimum meaningful chunk size
                functional_chunks.append({
                    "function": current_function,
                    "text": chunk_text,
                    "article_number": article_num,
                })
            current_chunk = []
        
        current_function = sentence_function
        current_chunk.append(sentence)
        
        # If chunk is getting too large, split it
        chunk_text = " ".join(current_chunk)
        if len(chunk_text.encode('utf-8')) > 500_000:  # ~500KB per functional chunk
            if current_chunk:
                chunk_text = " ".join(current_chunk[:-1]).strip()
                if len(chunk_text) > 20:
                    functional_chunks.append({
                        "function": current_function,
                        "text": chunk_text,
                        "article_number": article_num,
                    })
                current_chunk = [current_chunk[-1]]  # Keep last sentence for next chunk
    
    # Save final chunk
    if current_chunk:
        chunk_text = " ".join(current_chunk).strip()
        if len(chunk_text) > 20:
            functional_chunks.append({
                "function": current_function or 'other',
                "text": chunk_text,
                "article_number": article_num,
            })
    
    # If no functional chunks created (very short article), return as single chunk
    if not functional_chunks:
        functional_chunks.append({
            "function": detect_legal_function(text),
            "text": text,
            "article_number": article_num,
        })
    
    return functional_chunks

def extract_law_metadata(source_title: str) -> Dict:
    """Extract metadata about the law from source title."""
    metadata = {
        "document_title": source_title,
        "law_name": source_title,
        "domain": "unknown",
        "year": None,
    }
    
    # Enhanced domain detection from title
    title_lower = source_title.lower()
    title_amharic = source_title
    
    # Criminal law patterns
    if any(term in title_lower for term in ["criminal", "ወንጀል", "procedure", "prosecutor"]):
        metadata["domain"] = "criminal"
    # Commercial law patterns
    elif any(term in title_lower for term in ["commercial", "ንግድ", "business", "trade"]):
        metadata["domain"] = "commercial"
    # Civil law patterns
    elif any(term in title_lower for term in ["civil", "ነጋሪት", "contract", "tort"]):
        metadata["domain"] = "civil"
    # Family law patterns
    elif any(term in title_lower for term in ["family", "ቤተሰብ", "marriage", "divorce"]):
        metadata["domain"] = "family"
    # Labor law
    elif any(term in title_lower for term in ["labour", "labor", "employment", "worker"]):
        metadata["domain"] = "labor"
    # Court decisions
    elif any(term in title_lower for term in ["court", "decision", "cassation", "supreme", "volume"]):
        metadata["domain"] = "judicial"
    # Procedure
    elif any(term in title_lower for term in ["procedure", "procedural", "trial"]):
        metadata["domain"] = "procedure"
    
    # Try to extract year
    year_match = re.search(r"(19|20)\d{2}", source_title)
    if year_match:
        metadata["year"] = int(year_match.group())
    
    return metadata

def build_article_chunks(max_docs: int = None) -> pd.DataFrame:
    """Build article-level chunks from normalized documents."""
    extracted_files = sorted(cfg.extracted_dir.glob("*.json"))
    
    if max_docs is not None:
        extracted_files = extracted_files[:max_docs]
    
    chunk_records = []
    
    for json_path in tqdm(extracted_files, desc="Chunking Articles", unit="doc"):
        try:
            payload = json.loads(json_path.read_text(encoding="utf-8"))
            normalized_text = payload.get("normalized_text", "")
            source_title = payload.get("source_title", json_path.stem)
            source_path = payload.get("source_path", str(json_path))
            
            if not normalized_text:
                continue
            
            # Detect article structure
            articles = detect_article_structure(normalized_text)
            
            if not articles:
                # Fallback: if no articles detected, treat entire document as one chunk
                print(f"[WARN] No articles detected in {source_title}, treating as single chunk")
                articles = [{
                    "article_number": "1",
                    "text": normalized_text,
                    "sub_articles": [],
                    "start_line": 0,
                    "end_line": len(normalized_text.split("\n")),
                }]
            
            # Extract law metadata
            law_metadata = extract_law_metadata(source_title)
            
            # Create doc_id
            doc_id = uuid.uuid5(uuid.NAMESPACE_URL, f"{source_path}_{source_title}").hex
            
            # Create functional chunks for each article
            for article_idx, article in enumerate(articles):
                article_text = article["text"]
                article_num = article["article_number"]
                
                # Split article into functional units (obligation, prohibition, etc.)
                functional_units = split_into_functional_units(article_text, article_num)
                
                # Create chunk for each functional unit
                for func_idx, func_unit in enumerate(functional_units):
                    func_text = func_unit["text"]
                    func_type = func_unit["function"]
                    
                    # Check chunk size and split/truncate if too large
                    chunk_size_bytes = len(func_text.encode('utf-8'))
                    chunk_size_chars = len(func_text)
                    
                    if chunk_size_bytes > cfg.max_chunk_size_bytes or chunk_size_chars > cfg.max_chunk_chars:
                        # Truncate to safe size (functional chunks should be smaller)
                        max_chars = min(cfg.max_chunk_chars, 1_000_000)  # Cap at 1M chars for functional chunks
                        func_text = func_text[:max_chars] + "... [truncated]"
                        print(f"[WARN] Functional chunk {func_type} in article {article_num} ({source_title}) was truncated")
                    
                    # Create unique chunk ID
                    if len(functional_units) > 1:
                        chunk_id = f"{doc_id}-article-{article_num}-func{func_idx+1}-{func_type}"
                        display_article_num = f"{article_num}.{func_idx+1}"
                    else:
                        chunk_id = f"{doc_id}-article-{article_num}-{func_type}"
                        display_article_num = article_num
                    
                    chunk_records.append({
                        "chunk_id": chunk_id,
                        "doc_id": doc_id,
                        "article_number": display_article_num,
                        "sub_article_numbers": ",".join(article["sub_articles"]) if article["sub_articles"] else None,
                        "legal_function": func_type,  # NEW: functional metadata
                        "text": func_text,
                        "source_path": source_path,
                        "source_title": source_title,
                        "document_title": law_metadata["document_title"],
                        "law_name": law_metadata["law_name"],
                        "domain": law_metadata["domain"],
                        "year": law_metadata["year"],
                        "created_at": datetime.utcnow().isoformat(),
                    })
        except Exception as e:
            print(f"[ERROR] Failed to chunk {json_path}: {e}")
            continue
    
    df = pd.DataFrame(chunk_records)
    
    if df.empty:
        print("No chunks produced. Check extraction step.")
        return df
    
    # Save chunk manifest
    manifest_path = cfg.chunked_dir / "chunk_manifest.parquet"
    df.to_parquet(manifest_path, index=False)
    print(f"Saved chunk manifest with {len(df)} chunks → {manifest_path}")
    
    # Print statistics
    print(f"\nChunking Statistics:")
    print(f"  Total chunks: {len(df)}")
    print(f"  Unique documents: {df['doc_id'].nunique()}")
    print(f"  Domains: {df['domain'].value_counts().to_dict()}")
    if 'legal_function' in df.columns:
        print(f"  Legal functions: {df['legal_function'].value_counts().to_dict()}")
    
    return df

# Build chunks
chunk_df = build_article_chunks()
if not chunk_df.empty:
    chunk_df.head()


Chunking Articles:  16%|█▌        | 15/94 [00:00<00:00, 120.59doc/s]

[WARN] No articles detected in 1356, treating as single chunk
[WARN] No articles detected in 1364, treating as single chunk
[WARN] No articles detected in 1378, treating as single chunk
[WARN] No articles detected in Criminal Code (New) (Amharic), treating as single chunk
[WARN] No articles detected in civil-procedure-module-revised, treating as single chunk


Chunking Articles:  43%|████▎     | 40/94 [00:00<00:00, 73.91doc/s] 

[WARN] No articles detected in federal-supreme-court-decisions-volume-1-2-3, treating as single chunk
[WARN] No articles detected in federal-supreme-court-decisions-volume-4, treating as single chunk
[WARN] No articles detected in federal-supreme-court-decisions-volume-5, treating as single chunk
[WARN] No articles detected in federal-supreme-court-decisions-volume-6, treating as single chunk
[WARN] No articles detected in federal-supreme-court-decisions-volume-7, treating as single chunk
[WARN] No articles detected in labour-law, treating as single chunk
[WARN] No articles detected in module-on-professional-ethics, treating as single chunk
[WARN] No articles detected in modules-on-criminal-procedure-law, treating as single chunk
[WARN] No articles detected in principles-for-the-judiciary, treating as single chunk
[WARN] No articles detected in revised-criminal-procedure-module, treating as single chunk


Chunking Articles:  70%|███████   | 66/94 [00:00<00:00, 123.62doc/s]

[WARN] No articles detected in revised-criminal-procedure-module1, treating as single chunk
[WARN] No articles detected in revised-module-on-the-justice-system-and-the-role-of-justice-organs1, treating as single chunk
[WARN] No articles detected in the-role-of-prosecutor-in-criminal-investigation, treating as single chunk
[WARN] No articles detected in trial-and-pre-trial-managment, treating as single chunk
[WARN] No articles detected in trial-and-pre-trial, treating as single chunk
[WARN] No articles detected in volume 1-3, treating as single chunk


Chunking Articles: 100%|██████████| 94/94 [00:00<00:00, 101.35doc/s]


[WARN] No articles detected in volume 4, treating as single chunk
[WARN] No articles detected in volume 5, treating as single chunk
[WARN] No articles detected in volume 6, treating as single chunk
[WARN] No articles detected in volume 7, treating as single chunk
[WARN] No articles detected in አዋጅ-ቁጥር-1354, treating as single chunk
Saved chunk manifest with 1228 chunks → /Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/3_chunked/chunk_manifest.parquet

Chunking Statistics:
  Total chunks: 1228
  Unique documents: 65
  Domains: {'commercial': 793, 'judicial': 324, 'unknown': 80, 'criminal': 15, 'labor': 14, 'procedure': 2}
  Legal functions: {'other': 402, 'obligation': 292, 'definition': 248, 'permission': 157, 'condition': 127, 'prohibition': 2}


In [34]:
import faiss
import numpy as np
import time

# Normalize Ge'ez numerals function (needed for embedding)
def normalize_geez_numerals(text: str) -> str:
    """Convert Ge'ez numerals to Arabic numerals."""
    geez_to_arabic = {
        "፩": "1", "፪": "2", "፫": "3", "፬": "4", "፭": "5",
        "፮": "6", "፯": "7", "፰": "8", "፱": "9", "፲": "10",
    }
    normalized = text
    for geez, arabic in geez_to_arabic.items():
        normalized = normalized.replace(geez, arabic)
    return normalized

# Configure Gemini API
try:
    import google.generativeai as genai
    GEMINI_AVAILABLE = True
    if cfg.gemini_api_key:
        genai.configure(api_key=cfg.gemini_api_key)
        print("Gemini API configured for embeddings")
    else:
        print("[WARN] Gemini API key not found")
        GEMINI_AVAILABLE = False
except ImportError:
    print("[WARN] google-generativeai not installed")
    GEMINI_AVAILABLE = False

def embed_with_gemini(texts: List[str], task_type: str = "RETRIEVAL_DOCUMENT") -> np.ndarray:
    """Generate embeddings using Gemini API with batching and rate limiting."""
    if not GEMINI_AVAILABLE:
        raise ValueError("Gemini not available")
    
    embeddings = []
    batch_size = cfg.embedding_batch_size
    
    for i in tqdm(range(0, len(texts), batch_size), desc="Gemini embeddings"):
        batch = texts[i:i+batch_size]
        try:
            result = genai.embed_content(
                model=cfg.gemini_embedding_model,
                content=batch,
                task_type=task_type
            )
            batch_embeddings = result['embedding']
            embeddings.extend(batch_embeddings)
            
            # Rate limiting: Free tier limit is 100 requests/minute
            if i + batch_size < len(texts):
                time.sleep(0.7)
        except Exception as e:
            print(f"[ERROR] Batch {i} failed: {e}")
            # Retry once after delay
            time.sleep(2)
            try:
                result = genai.embed_content(
                    model=cfg.gemini_embedding_model,
                    content=batch,
                    task_type=task_type
            )
                batch_embeddings = result['embedding']
                embeddings.extend(batch_embeddings)
            except Exception as e2:
                print(f"[ERROR] Retry failed for batch {i}: {e2}")
                # Fallback to zeros (will cause issues, but allows continuation)
                embeddings.extend([np.zeros(768) for _ in batch])
    
    return np.array(embeddings, dtype=np.float32)

def build_faiss_index(df: pd.DataFrame) -> faiss.Index:
    """
    Build FAISS index from chunks with size validation.
    Ensures numerals are normalized before embedding to improve semantic similarity.
    """
    if df.empty:
        raise ValueError("No chunks to embed")
    
    # Filter and validate chunks before embedding
    print(f"Validating {len(df)} chunks for size limits...")
    
    valid_chunks = []
    valid_indices = []
    skipped_count = 0
    
    for idx, row in df.iterrows():
        text = row["text"]
        
        # Ensure numerals are normalized before embedding (critical for semantic similarity)
        # This should already be done, but double-check to avoid numeral-driven similarity
        text = normalize_geez_numerals(text)
        
        chunk_size_bytes = len(text.encode('utf-8'))
        chunk_size_chars = len(text)
        
        # Skip chunks that are too large (shouldn't happen after chunking fix, but double-check)
        if chunk_size_bytes > cfg.max_chunk_size_bytes:
            print(f"[WARN] Skipping chunk {row.get('chunk_id', idx)}: too large ({chunk_size_bytes:,} bytes)")
            skipped_count += 1
            continue
        
        # Truncate if slightly over character limit
        if chunk_size_chars > cfg.max_chunk_chars:
            text = text[:cfg.max_chunk_chars] + "... [truncated]"
            print(f"[WARN] Truncating chunk {row.get('chunk_id', idx)}: {chunk_size_chars:,} chars")
        
        valid_chunks.append(text)
        valid_indices.append(idx)
    
    if skipped_count > 0:
        print(f"Skipped {skipped_count} chunks that were too large")
    
    if not valid_chunks:
        raise ValueError("No valid chunks to embed after size validation")
    
    print(f"Embedding {len(valid_chunks)} valid chunks...")
    
    if GEMINI_AVAILABLE:
        embeddings = embed_with_gemini(valid_chunks, task_type="RETRIEVAL_DOCUMENT")
    else:
        raise ValueError("Gemini embeddings required but not available")
    
    # Filter dataframe to only include valid chunks
    df_valid = df.loc[valid_indices].copy()
    
    dimension = embeddings.shape[1]
    print(f"Embedding dimension: {dimension}")
    
    # Create FAISS index
    if cfg.index_path.exists():
        print(f"Removing existing index to create fresh one")
        cfg.index_path.unlink()
    
    print("Creating new HNSW index")
    index = faiss.IndexHNSWFlat(dimension, 32)
    index.hnsw.efConstruction = 200
    index.hnsw.efSearch = 64
    index.add(embeddings.astype("float32"))
    
    # Save metadata (only valid chunks)
    df_valid.to_parquet(cfg.metadata_path, index=False)
    
    # Save index
    faiss.write_index(index, str(cfg.index_path))
    
    print(f"Index now holds {index.ntotal} vectors; saved to {cfg.index_path}")
    print(f"Metadata has {len(df_valid)} rows; saved to {cfg.metadata_path}")
    
    # Verify they match
    if index.ntotal != len(df_valid):
        print(f"[WARN] Mismatch: index has {index.ntotal} vectors but metadata has {len(df_valid)} rows")
    else:
        print(f"✓ Index and metadata are in sync ({index.ntotal} vectors/rows)")
    
    return index

# Build index if chunks exist
if not chunk_df.empty:
    faiss_index = build_faiss_index(chunk_df)
else:
    print("No chunks to embed. Run chunking step first.")


Gemini API configured for embeddings
Validating 1228 chunks for size limits...
Embedding 1228 valid chunks...


Gemini embeddings: 100%|██████████| 13/13 [00:44<00:00,  3.45s/it]

Embedding dimension: 768
Removing existing index to create fresh one
Creating new HNSW index
Index now holds 1228 vectors; saved to /Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/4_vector_db/faiss_index.bin
Metadata has 1228 rows; saved to /Users/blank/Documents/Foundation Models Course Projects/rag_pipeline/4_vector_db/metadata.parquet
✓ Index and metadata are in sync (1228 vectors/rows)





In [35]:
from typing import Tuple

# Normalize Ge'ez numerals function (needed for retrieval)
def normalize_geez_numerals(text: str) -> str:
    """Convert Ge'ez numerals to Arabic numerals."""
    geez_to_arabic = {
        "፩": "1", "፪": "2", "፫": "3", "፬": "4", "፭": "5",
        "፮": "6", "፯": "7", "፰": "8", "፱": "9", "፲": "10",
    }
    normalized = text
    for geez, arabic in geez_to_arabic.items():
        normalized = normalized.replace(geez, arabic)
    return normalized

def load_index_and_metadata() -> Tuple[faiss.Index, pd.DataFrame]:
    """Load FAISS index and metadata."""
    if not cfg.index_path.exists():
        raise FileNotFoundError("FAISS index not found. Run embedding step first.")
    
    index = faiss.read_index(str(cfg.index_path))
    
    if not cfg.metadata_path.exists():
        raise FileNotFoundError(f"Metadata file missing: {cfg.metadata_path}")
    
    df = pd.read_parquet(cfg.metadata_path)
    
    return index, df

def retrieve_legal_context(
    query: str,
    top_k: int = None,
    domain_filter: Optional[str] = None,
    function_filter: Optional[str] = None
) -> pd.DataFrame:
    """
    Retrieve relevant legal articles for a query.
    
    For sentence simplification: retrieve 2 chunks (hard limit).
    More context does not improve simplification quality and can cause:
    - Copying legal language instead of simplifying
    - Losing aggressive simplification
    - Cognitive overload for the generator
    
    Args:
        query: Legal sentence or query text (in Amharic)
        top_k: Number of chunks to retrieve (default: cfg.top_k_retrieval, hard cap at 2 for simplification)
        domain_filter: Optional domain filter ("criminal", "commercial", "civil", "family")
        function_filter: Optional legal function filter ("obligation", "prohibition", "permission", "condition", "definition")
    
    Returns:
        DataFrame with retrieved chunks and metadata
    """
    if top_k is None:
        top_k = cfg.top_k_retrieval
    
    # Hard cap at 2 for simplification (max 3 only if sentence has explicit conditions+exceptions)
    # This prevents context overload and ensures cleaner, more consistent simplification
    if top_k > 2:
        top_k = 2
    
    index, df = load_index_and_metadata()
    
    # Verify index and metadata are in sync
    index_size = index.ntotal
    metadata_size = len(df)
    if index_size != metadata_size:
        print(f"[WARN] Index size ({index_size}) doesn't match metadata size ({metadata_size})")
        max_valid_idx = min(index_size, metadata_size) - 1
    else:
        max_valid_idx = metadata_size - 1
    
    # Normalize query numerals before embedding (critical for semantic similarity)
    query_normalized = normalize_geez_numerals(query)
    
    # Generate query embedding
    if GEMINI_AVAILABLE:
        query_emb = embed_with_gemini([query_normalized], task_type="RETRIEVAL_QUERY")[0]
    else:
        raise ValueError("Gemini embeddings required but not available")
    
    query_emb = query_emb.astype("float32").reshape(1, -1)
    
    # Search (retrieve more if filtering to ensure we get enough after filtering)
    search_k = top_k * 5 if (domain_filter or function_filter) else top_k
    distances, indices = index.search(query_emb, search_k)
    
    hits = []
    for rank, (score, idx) in enumerate(zip(distances[0], indices[0])):
        if idx == -1 or idx > max_valid_idx or idx < 0:
            continue
        
        try:
            row = df.iloc[idx].to_dict()
            
            # Apply domain filter if specified
            if domain_filter and row.get("domain") != domain_filter:
                continue
            
            # Apply function filter if specified
            if function_filter and row.get("legal_function") != function_filter:
                continue
            
            row.update({
                "rank": len(hits) + 1,
                "score": float(score),
            })
            hits.append(row)
            
            if len(hits) >= top_k:
                break
        except IndexError as e:
            print(f"[WARN] Failed to access row {idx}: {e}")
            continue
    
    return pd.DataFrame(hits)

# Test retrieval
if cfg.index_path.exists():
    test_query = "የንግድ ስምምነት መፈጸም"
    print(f"Testing retrieval with query: {test_query}")
    results = retrieve_legal_context(test_query, top_k=2)  # Hard limit: 2 chunks for simplification
    if not results.empty:
        print(f"\nRetrieved {len(results)} chunks:")
        for _, row in results.iterrows():
            print(f"\nRank {row['rank']} (score: {row['score']:.4f}):")
            print(f"  Law: {row['law_name']}")
            print(f"  Article: {row['article_number']}")
            print(f"  Domain: {row['domain']}")
            if 'legal_function' in row:
                print(f"  Function: {row['legal_function']}")
            print(f"  Text preview: {row['text'][:200]}...")
    else:
        print("No results found.")
else:
    print("Index not found. Run embedding step first.")


Testing retrieval with query: የንግድ ስምምነት መፈጸም


Gemini embeddings: 100%|██████████| 1/1 [00:00<00:00,  3.05it/s]


Retrieved 2 chunks:

Rank 1 (score: 0.7006):
  Law: Ethiopia-Commercial-Code-Amharic-Proclamation-No.-1243_2021
  Article: 1.734
  Domain: commercial
  Function: other
  Text preview: ገንዘብ ጠያቂውም በክርክር ሂዯቱ ተሳታፉ ሇመሆን አይችሌም።...

Rank 2 (score: 0.7009):
  Law: 1374
  Article: 9.2
  Domain: unknown
  Function: other
  Text preview: ቀጥታ ስርጭትን በሚመለከት ዝርዝሩ ቦርዱ በሚያወጣው መመሪያ ይወሰናል።...





# 8. Usage Example

Example of how to use the RAG system for legal text simplification context.


In [36]:
def format_retrieved_context(query: str, retrieved_chunks: pd.DataFrame) -> str:
    """
    Format retrieved chunks as context for the simplification model.
    This is what you would pass to your fine-tuned AfriByT5 model.
    """
    if retrieved_chunks.empty:
        return f"Legal sentence to simplify: {query}\n\nNo relevant legal context found."
    
    context_parts = [f"Legal sentence to simplify: {query}", "\n\nRelevant legal context:"]
    
    for _, row in retrieved_chunks.iterrows():
        function_info = f"Function: {row.get('legal_function', 'unknown')}" if 'legal_function' in row else ""
        context_parts.append(
            f"\n---\n"
            f"Law: {row['law_name']}\n"
            f"Article: {row['article_number']}\n"
            f"Domain: {row['domain']}\n"
            f"{function_info}\n" if function_info else ""
            f"Text: {row['text']}"
        )
    
    return "\n".join(context_parts)

# Example usage
example_legal_sentence = "በዚህ አዋጅ መሠረት የተወሰነው ውሳኔ በሁሉም አካላት መከበር አለበት።"

print("Example: Retrieving context for legal sentence simplification")
print(f"\nOriginal legal sentence:\n{example_legal_sentence}")

if cfg.index_path.exists():
    retrieved = retrieve_legal_context(example_legal_sentence, top_k=2)  # Hard limit: 2 chunks for simplification
    
    if not retrieved.empty:
        formatted_context = format_retrieved_context(example_legal_sentence, retrieved)
        print("\n" + "="*80)
        print("Formatted context for model:")
        print("="*80)
        print(formatted_context)
        print("\n" + "="*80)
        print("\nThis formatted context would be passed to your fine-tuned model.")
        print("The model uses this context to understand legal meaning before simplifying.")
    else:
        print("\nNo relevant context retrieved.")
else:
    print("\nIndex not found. Run embedding step first.")


Example: Retrieving context for legal sentence simplification

Original legal sentence:
በዚህ አዋጅ መሠረት የተወሰነው ውሳኔ በሁሉም አካላት መከበር አለበት።


Gemini embeddings: 100%|██████████| 1/1 [00:00<00:00,  3.04it/s]


Formatted context for model:
Legal sentence to simplify: በዚህ አዋጅ መሠረት የተወሰነው ውሳኔ በሁሉም አካላት መከበር አለበት።


Relevant legal context:

---
Law: 6cf83-23-e18ba8e18d8c.e18ca0.e18d8d.e189a4e189b5-e188a0e189a0e188ad-e188b0e1889a-e189bde1888ee189b5-e18b8de188b3e18a94e18b8ee189bd
Article: 89.17
Domain: unknown
Function: definition


---
Law: Ethiopia-Commercial-Code-Amharic-Proclamation-No.-1243_2021
Article: 1.780
Domain: commercial
Function: other



This formatted context would be passed to your fine-tuned model.
The model uses this context to understand legal meaning before simplifying.





# 9. Session Recovery & Status

Check processing status and resume from where you left off.


In [37]:
def check_processing_status() -> Dict:
    """Check the status of the RAG pipeline."""
    status = {
        "pdfs_found": len(list(cfg.raw_pdf_dir.glob("*.pdf"))),
        "pdfs_extracted": len(list(cfg.extracted_dir.glob("*.json"))),
        "pdfs_normalized": len(list(cfg.normalized_dir.glob("*.txt"))),
        "chunks_created": 0,
        "index_exists": cfg.index_path.exists(),
        "metadata_exists": cfg.metadata_path.exists(),
    }
    
    if cfg.metadata_path.exists():
        try:
            df = pd.read_parquet(cfg.metadata_path)
            status["chunks_created"] = len(df)
            status["unique_documents"] = df["doc_id"].nunique()
            status["domains"] = df["domain"].value_counts().to_dict()
        except Exception:
            pass
    
    if cfg.index_path.exists():
        try:
            index = faiss.read_index(str(cfg.index_path))
            status["index_vectors"] = index.ntotal
        except Exception:
            pass
    
    return status

status = check_processing_status()
print("RAG Pipeline Status:")
print("="*50)
for key, value in status.items():
    print(f"{key}: {value}")
print("="*50)

if status["pdfs_extracted"] < status["pdfs_found"]:
    print(f"\n⚠️  {status['pdfs_found'] - status['pdfs_extracted']} PDFs still need extraction")
    print("   Run: extract_and_normalize_pdfs(max_docs=None)")

if status["chunks_created"] == 0 and status["pdfs_extracted"] > 0:
    print(f"\n⚠️  Chunking needed for {status['pdfs_extracted']} extracted documents")
    print("   Run: build_article_chunks()")

if not status["index_exists"] and status["chunks_created"] > 0:
    print(f"\n⚠️  Indexing needed for {status['chunks_created']} chunks")
    print("   Run: build_faiss_index(chunk_df)")

if status["index_exists"] and status["chunks_created"] > 0:
    print("\n✓ RAG system is ready for retrieval!")


RAG Pipeline Status:
pdfs_found: 94
pdfs_extracted: 94
pdfs_normalized: 94
chunks_created: 1228
index_exists: True
metadata_exists: True
unique_documents: 65
domains: {'commercial': 793, 'judicial': 324, 'unknown': 80, 'criminal': 15, 'labor': 14, 'procedure': 2}
index_vectors: 1228

✓ RAG system is ready for retrieval!
