
## Table of Contents

### [0. Global Vars](#0-global-vars)
Configuration constants and system parameters for all three RAG stages

### [1. Process PDFs](#1-process-pdfs)
Document preprocessing pipeline with Vietnamese text optimization
- [1.1 Text Simple Cleaner and Metadata](#11-text-simple-cleaner-and-metadata)
- [1.2 Run PDF Processing Pipeline](#12-run-pdf-processing-pipeline)

### [2. Build RAG](#2-build-rag)
Vector database construction with hybrid BGE-M3 embeddings
- [2.1 Milvus Client Connection](#21-milvus-client-connection)
- [2.2 Milvus Builder and Search](#22-milvus-builder-and-search)
- [2.3 Embedding Model and Builder Vector Store](#23-embedding-model-and-builder-vector-store)

### [3. Retrieve from RAG](#3-retrieve-from-rag)
Interactive retrieval system with LLM answer generation
- [3.1 LLM Definition and Call Managements](#31-llm-definition-and-call-managements)
- [3.2 Load retriever and models](#32-load-retriever-and-models)
- [3.3 Define RAG Chain and Run](#33-define-rag-chain-and-run)


___
## 0. Global Vars

In [None]:
# ---------------------------------------------------------------------------
# Preprocess
# ---------------------------------------------------------------------------
# Rough characters-per-token ratio used by TextCleaner.estimate_tokens.
CHARS_PER_TOKEN = 3
# Tokenizer used by the chunker; also reused as the embedding model id below.
CHUNKER_MODEL = "BAAI/bge-m3"
MAX_CHUNK_TOKENS = 1024
OVERLAP_TOKENS = 128

# ---------------------------------------------------------------------------
# Build RAG
# ---------------------------------------------------------------------------
# Dense index used when the Docker Milvus server is reachable.
DENSE_INDEX_CONFIG = {
    "index_type": "HNSW",
    "metric_type": "COSINE",
    "params": {"M": 16, "efConstruction": 64},
}
# Dense index used with the local file-based Milvus.
DENSE_INDEX_FALLBACK_CONFIG = {
    "index_type": "IVF_FLAT",
    "metric_type": "COSINE",
    "params": {"nlist": 256},
}
DENSE_SEARCH_FALLBACK_PARAMS = {"metric_type": "COSINE", "params": {"nprobe": 8}}
# Search params paired with DENSE_INDEX_CONFIG. The metric must be the one the
# index was built with (COSINE), and HNSW is tuned through `ef` (which must be
# at least the search limit); `drop_ratio_search` only applies to sparse indexes.
DENSE_SEARCH_PARAMS = {"metric_type": "COSINE", "params": {"ef": 64}}
MILVUS_DOCKER_URI = "http://localhost:19530"
MILVUS_URI = "data/milvus.db"
USE_DOCKER_MILVUS = False

# Smoothing constant handed to Milvus' RRFRanker in hybrid_search.
RRF_K = 30
SPARSE_INDEX_CONFIG = {
    "index_type": "SPARSE_INVERTED_INDEX",
    "metric_type": "IP",
    "params": {"drop_ratio_build": 0.2},
}
SPARSE_SEARCH_PARAMS = {"metric_type": "IP", "params": {"drop_ratio_search": 0.2}}

COLLECTION_NAME = "vndoc_rag_hybrid"
EMBED_MODEL_ID = CHUNKER_MODEL
# Dimension of the dense vector field created in the Milvus schema.
EMBEDDING_DIM = 1024
ENCODE_KWARGS = {
    "normalize_embeddings": True,
    "batch_size": 8,
    "return_dense": True,
    "return_sparse": True,
    "return_colbert_vecs": False,
}

# ---------------------------------------------------------------------------
# Retrieve
# ---------------------------------------------------------------------------
MAX_OUTPUT_TOKENS = 4096
TEMPERATURE = 0
# Prompt template; filled with `query` and `context` via str.format.
PROMPT = """Bạn là một chuyên gia trí tuệ nhân tạo và học máy có kiến thức sâu rộng. Hãy trả lời câu hỏi dựa trên thông tin được cung cấp từ hệ thống RAG.\n\nCÂU HỎI: {query}\n\nTHÔNG TIN THAM KHẢO:\n{context}\n\nHƯỚNG DẪN TRẢ LỜI:\n1. Trả lời bằng tiếng Việt một cách chi tiết và rõ ràng\n2. Chỉ sử dụng thông tin có trong các tài liệu được cung cấp - KHÔNG tự bịa đặt hoặc thêm thông tin\n3. Mỗi nguồn tài liệu là riêng biệt - KHÔNG trộn lẫn hoặc kết hợp thông tin từ các nguồn khác nhau một cách tùy tiện\n4. Nếu thông tin từ các nguồn khác nhau mâu thuẫn, hãy chỉ ra sự khác biệt này\n5. Giải thích các thuật ngữ kỹ thuật bằng tiếng Việt\n6. Trích dẫn rõ ràng nguồn thông tin cho mỗi phần trả lời\n7. Nếu thông tin không đủ để trả lời đầy đủ, hãy thừa nhận điều này thay vì đoán"""

# NOTE(review): the constants below are not referenced in the visible part of
# the notebook; presumably consumed by the retriever/reranker cells further down.
DEFAULT_K = 10
RERANK_TOP_K = 3
SIMILARITY_THRESHOLD = 0.2
RERANKER_MODEL_ID = "BAAI/bge-reranker-v2-m3"

## 1. Process PDFs

In [None]:
import gc
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple

from docling.chunking import HybridChunker
from langchain.schema import Document
from langchain_docling import DoclingLoader
from langchain_docling.loader import ExportType
from tqdm.auto import tqdm
import unicodedata
import hashlib
import re
import datetime
import uuid

### 1.1 Text Simple Cleaner and Metadata

In [None]:
class TextCleaner:
    """Vietnamese text cleaner optimized for RAG preprocessing.

    The cleaner NFC-normalizes the text, collapses repeated noise
    (spaces, ellipses, dashes, commas), removes every character outside a
    whitelist (ASCII letters/digits, Vietnamese letters, whitespace and
    common punctuation) and finally flattens the text to a single line.
    """

    def __init__(self):
        # Allow Vietnamese letters + digits + punctuation
        self.vietnamese_chars = (
            "àáạảãâầấậẩẫăằắặẳẵ"
            "èéẹẻẽêềếệểễ"
            "ìíịỉĩ"
            "òóọỏõôồốộổỗơờớợởỡ"
            "ùúụủũưừứựửữ"
            "ỳýỵỷỹ"
            "đ"
            "ÀÁẠẢÃÂẦẤẬẨẪĂẰẮẶẲẴ"
            "ÈÉẸẺẼÊỀẾỆỂỄ"
            "ÌÍỊỈĨ"
            "ÒÓỌỎÕÔỒỐỘỔỖƠỜỚỢỞỠ"
            "ÙÚỤỦŨƯỪỨỰỬỮ"
            "ỲÝỴỶỸ"
            "Đ"
        )
        # Compiled once: matches any character that is NOT whitelisted, so the
        # whole text is filtered in a single pass instead of one regex call per
        # character. The em dash is whitelisted because clean_text itself
        # produces it when collapsing runs of hyphens.
        self._disallowed_chars = re.compile(
            "[^a-zA-Z0-9"
            + re.escape(self.vietnamese_chars)
            + r"\s.,;:!?()\[\]{}\"'\-_/\\+=%&@#$—]"
        )

    def normalize_unicode(self, text: str) -> str:
        """Normalize Unicode (NFC) for consistency."""
        return unicodedata.normalize("NFC", text)

    def clean_text(self, text: str) -> tuple[str, Dict[str, Any]]:
        """Clean text and return it together with bookkeeping metadata.

        Args:
            text: Raw chunk text (may be empty or whitespace only).

        Returns:
            ``(cleaned_text, metadata)``. For empty/blank input the result is
            ``("", {"was_empty": True})``. Otherwise ``metadata`` holds
            ``was_empty``, ``original_length``, ``cleaned_length``,
            ``word_count``, ``paragraphs`` and ``estimated_tokens``.
        """
        if not text or not text.strip():
            return "", {"was_empty": True}

        original_length = len(text)

        # Unicode normalization
        text = self.normalize_unicode(text)

        # Basic noise cleanup
        text = re.sub(r"[ \t]+", " ", text)  # collapse spaces/tabs
        text = re.sub(r"\n\s*\n+", "\n\n", text)  # normalize paragraph breaks
        text = re.sub(r"[.]{3,}", "...", text)  # reduce ellipses
        text = re.sub(r"[-]{3,}", "—", text)  # convert long dashes
        text = re.sub(r"[,]{2,}", ",", text)  # collapse commas

        # Remove stray non-text characters (OCR artifacts, control chars)
        text = self._disallowed_chars.sub("", text)

        # Paragraphs must be counted here: the whitespace collapse below
        # removes every line break from the final text.
        paragraph_count = text.strip().count("\n\n") + 1

        # Fix OCR spacing around punctuation without damaging numbers:
        #   "word ,next" -> "word, next"   but   "0.001", "10:30", "..." stay.
        text = re.sub(r"\s+([.,;:!?])", r"\1", text)
        # Add the missing space only when a letter follows the punctuation.
        text = re.sub(r"([.,;:!?])(?=[^\W\d_])", r"\1 ", text)
        text = re.sub(r"\s+", " ", text)

        cleaned_text = text.strip()

        # Metadata for tracking
        cleaning_metadata = {
            "was_empty": False,
            "original_length": original_length,
            "cleaned_length": len(cleaned_text),
            "word_count": len(cleaned_text.split()),
            "paragraphs": paragraph_count,
            "estimated_tokens": self.estimate_tokens(cleaned_text),
        }

        return cleaned_text, cleaning_metadata

    def estimate_tokens(self, text: str) -> int:
        """Rough token estimation: ``len(text) // CHARS_PER_TOKEN + 1`` (0 for empty text)."""
        if not text:
            return 0
        return len(text) // CHARS_PER_TOKEN + 1


class ChunkMetadataGenerator:
    """Builds identifiers and bookkeeping metadata for processed chunks.

    Every instance gets a random 8-character session id that is embedded in
    all chunk/document ids it produces, plus a running counter of the chunks
    it has described so far.
    """

    def __init__(self):
        self.session_id = str(uuid.uuid4())[:8]
        self.chunk_counter = 0

    @staticmethod
    def _short_file_hash(source_file: str) -> str:
        """Return the first 8 hex digits of the MD5 of the source path."""
        digest = hashlib.md5(source_file.encode()).hexdigest()
        return digest[:8]

    def generate_chunk_id(self, source_file: str, chunk_index: int) -> str:
        """Return ``CHUNK_<session>_<file hash>_<zero-padded index>``."""
        return "CHUNK_{}_{}_{:04d}".format(
            self.session_id, self._short_file_hash(source_file), chunk_index
        )

    def generate_document_id(self, source_file: str) -> str:
        """Return ``DOC_<session>_<file hash>`` for the source file."""
        return "DOC_{}_{}".format(
            self.session_id, self._short_file_hash(source_file)
        )

    def create_metadata(
        self,
        doc: Document,
        source_file: str,
        chunk_index: int,
        cleaning_metadata: Dict[str, Any],
        estimated_tokens: int,
    ) -> Dict[str, Any]:
        """Assemble the metadata dict stored alongside one chunk.

        Args:
            doc: Chunk whose ``page_content`` is measured and hashed.
            source_file: Path of the originating file, as a string.
            chunk_index: Position of the chunk within its source file.
            cleaning_metadata: Output of the text cleaner; only ``word_count``
                is read (defaults to 0 when absent).
            estimated_tokens: Token estimate stored verbatim.

        Returns:
            Dict with ids, processing info, source info and content metrics.
        """
        gc.collect()
        self.chunk_counter += 1

        text = doc.page_content

        ids = {
            "chunk_id": self.generate_chunk_id(source_file, chunk_index),
            "document_id": self.generate_document_id(source_file),
            "chunk_index": chunk_index,
            "global_index": self.chunk_counter,
        }
        processing = {
            "session_id": self.session_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "chunker_model": CHUNKER_MODEL,
        }
        source = {
            "source_file": source_file,
            "source_filename": str(Path(source_file).name),
        }
        metrics = {
            "content_length": len(text),
            "estimated_tokens": estimated_tokens,
            "word_count": cleaning_metadata.get("word_count", 0),
            "content_hash": hashlib.sha256(text.encode()).hexdigest(),
        }

        # Merge in a fixed order so the key order of the result is stable.
        metadata: Dict[str, Any] = {}
        for section in (ids, processing, source, metrics):
            metadata.update(section)
        return metadata

### 1.2 Run PDF Processing Pipeline

In [None]:
# Shared helpers reused for every chunk of every PDF processed in this run.
text_cleaner = TextCleaner()
metadata_generator = ChunkMetadataGenerator()

# PDFs to process. The processing loop expects path-like entries
# (e.g. Path("docs/a.pdf")); an empty list means nothing is processed.
pdf_files: List[Path] = []
# Accumulates the cleaned, metadata-enriched chunks from all PDFs.
all_documents: List[Document] = []

In [None]:
chunker = HybridChunker(
    tokenizer=CHUNKER_MODEL,
    max_tokens=MAX_CHUNK_TOKENS,
    overlap_tokens=OVERLAP_TOKENS,
)

# Process each PDF file
pbar = tqdm(total=len(pdf_files), desc="Processing PDFs")

for pdf_file in pdf_files:
    # Accept both str and Path entries; only the display name needs a Path.
    pdf_name = Path(pdf_file).name
    pdf_file_str = str(pdf_file)
    pbar.set_description(f"Processing {pdf_name}")

    try:
        # Use DoclingLoader with chunker
        loader = DoclingLoader(
            file_path=pdf_file_str,
            export_type=ExportType.DOC_CHUNKS,
            chunker=chunker,
        )

        # Load documents - already chunked
        docs = loader.load()

        for i, doc in enumerate(docs):
            # Clean text
            cleaned_content, cleaning_metadata = text_cleaner.clean_text(
                doc.page_content
            )

            # Chunks of 50 characters or fewer after cleaning are dropped.
            if not cleaned_content or len(cleaned_content) <= 50:
                continue

            estimated_tokens = text_cleaner.estimate_tokens(cleaned_content)

            # Create metadata (chunk_index keeps the loader's original position)
            metadata = metadata_generator.create_metadata(
                doc=doc,
                source_file=pdf_file_str,
                chunk_index=i,
                cleaning_metadata=cleaning_metadata,
                estimated_tokens=estimated_tokens,
            )

            # Update document with cleaned content and metadata
            doc.page_content = cleaned_content
            doc.metadata = metadata

            all_documents.append(doc)

        del loader

    except Exception as e:
        # Best effort: one unreadable PDF must not abort the whole run, but
        # the failure is reported instead of being silently dropped.
        print(f"⚠️ Skipping {pdf_name}: {e}")

    finally:
        pbar.update(1)

del chunker
pbar.close()

In [None]:
# Output of this stage: `all_documents` (later passed to build_hybrid_vectorstore).
# Preview the first two processed chunks.
display(all_documents[:2])

___
## 2. Build RAG

In [None]:
from langchain.schema import Document
from typing import Union
from pymilvus import (
    AnnSearchRequest,
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    MilvusClient,
    RRFRanker,
    connections,
)
from pymilvus.milvus_client.index import IndexParams
from FlagEmbedding import BGEM3FlagModel

### 2.1 Milvus Client Connection

In [None]:
class MilvusConnectionManager:
    """Creates and caches a single Milvus client.

    When ``USE_DOCKER_MILVUS`` is set, the Docker server is tried first; on
    any failure (or when the flag is off) the local file-based database is
    used instead. ``supports_hnsw`` records which of the two is active.
    """

    def __init__(self):
        self.client: Optional[MilvusClient] = None
        self.supports_hnsw = False

    def _connect_docker(self) -> bool:
        """Try the Docker server; return True when it answers a test call."""
        try:
            print("🐳 Connecting to Docker Milvus...")
            self.client = MilvusClient(uri=MILVUS_DOCKER_URI)
            # Test connection
            self.client.list_collections()
            self.supports_hnsw = True
            print("✅ Docker Milvus connected (HNSW enabled)")
        except Exception as e:
            print(f"⚠️ Docker connection failed: {e}")
            return False
        return True

    def _connect_local(self) -> None:
        """Open the local file-based database (no HNSW)."""
        print("📁 Using local file-based Milvus...")
        self.client = MilvusClient(uri=str(MILVUS_URI))
        self.supports_hnsw = False
        print("✅ Local Milvus connected")

    def get_client(self) -> Tuple[MilvusClient, bool]:
        """Return ``(client, supports_hnsw)``, connecting on first use."""
        if self.client is None:
            docker_ok = USE_DOCKER_MILVUS and self._connect_docker()
            if not docker_ok:
                self._connect_local()
        return self.client, self.supports_hnsw

    def get_index_config(self):
        """Return ``(dense index config, dense search params)`` for the active backend."""
        return (
            (DENSE_INDEX_CONFIG, DENSE_SEARCH_PARAMS)
            if self.supports_hnsw
            else (DENSE_INDEX_FALLBACK_CONFIG, DENSE_SEARCH_FALLBACK_PARAMS)
        )


# One shared manager so every helper in this notebook reuses the same client.
_connection_manager = MilvusConnectionManager()


def get_milvus_client() -> Tuple[MilvusClient, bool]:
    """Return ``(client, supports_hnsw)`` from the shared connection manager."""
    client, supports_hnsw = _connection_manager.get_client()
    return client, supports_hnsw


def get_index_config():
    """Return ``(dense index config, dense search params)`` for the shared connection."""
    index_config, search_params = _connection_manager.get_index_config()
    return index_config, search_params

### 2.2 Milvus Builder and Search

In [None]:
def ensure_hybrid_collection(client: MilvusClient, name: str, dense_dim: int) -> bool:
    """
    (Re)create the hybrid dense+sparse collection from scratch.

    An existing collection with the same name is dropped first, so all
    previously stored data is discarded on every call.

    Args:
        client: Milvus client used for all operations
        name: Collection name
        dense_dim: Dimension of the dense vector field

    Returns:
        Always True (the collection is created on every call)
    """
    # Start from a clean slate: remove any previous version of the collection.
    already_there = client.has_collection(name)
    if already_there:
        client.drop_collection(name)

    id_field = FieldSchema(
        name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=64
    )
    text_field = FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535)
    dense_field = FieldSchema(
        name="dense_vector", dtype=DataType.FLOAT_VECTOR, dim=dense_dim
    )
    sparse_field = FieldSchema(
        name="sparse_vector", dtype=DataType.SPARSE_FLOAT_VECTOR
    )
    metadata_field = FieldSchema(name="metadata", dtype=DataType.JSON)

    schema = CollectionSchema(
        fields=[id_field, text_field, dense_field, sparse_field, metadata_field],
        description="Hybrid dense+sparse collection for BGE-M3",
    )

    client.create_collection(
        collection_name=name, schema=schema, consistency_level="Strong"
    )
    return True


def build_indexes(client: MilvusClient, name: str) -> None:
    """Build the dense and sparse indexes, then load the collection.

    The dense index type (HNSW or IVF_FLAT) comes from ``get_index_config()``;
    the sparse index always uses ``SPARSE_INDEX_CONFIG``.

    Args:
        client: Milvus client
        name: Name of the collection to index

    Raises:
        Exception: Any error raised by the Milvus client is propagated.
    """
    # Only the index half of the config is needed here.
    dense_index_config, _ = get_index_config()

    dense_index_params = IndexParams()
    dense_index_params.add_index(
        field_name="dense_vector",
        index_type=dense_index_config["index_type"],
        metric_type=dense_index_config["metric_type"],
        params=dense_index_config["params"],
    )
    client.create_index(
        collection_name=name,
        index_params=dense_index_params,
    )

    sparse_index_params = IndexParams()
    sparse_index_params.add_index(
        field_name="sparse_vector",
        index_type=SPARSE_INDEX_CONFIG["index_type"],
        metric_type=SPARSE_INDEX_CONFIG["metric_type"],
        params=SPARSE_INDEX_CONFIG["params"],
    )
    client.create_index(
        collection_name=name,
        index_params=sparse_index_params,
    )

    # The collection must be loaded before it can be searched.
    client.load_collection(name)


def insert_documents(
    client: MilvusClient,
    name: str,
    dense_vecs,
    sparse_vecs,
    docs: List[Document],
) -> None:
    """Insert one batch of documents with their dense and sparse vectors.

    Args:
        client: Milvus client
        name: Target collection name
        dense_vecs: One dense vector per document; each item must expose
            ``.tolist()`` (e.g. rows of a numpy array)
        sparse_vecs: One sparse vector per document (dict token_id -> weight)
        docs: Documents to store; ``page_content`` and ``metadata`` are saved

    Raises:
        ValueError: If the number of vectors does not match the number of
            documents (e.g. the encoder returned an empty result).
        Exception: Any error raised by ``client.insert`` is propagated.
    """
    if not docs:
        # Nothing to insert; avoid sending an empty payload to Milvus.
        return

    if len(dense_vecs) != len(docs) or len(sparse_vecs) != len(docs):
        raise ValueError(
            f"Embedding count mismatch: {len(docs)} docs, "
            f"{len(dense_vecs)} dense vectors, {len(sparse_vecs)} sparse vectors"
        )

    rows = [
        {
            "id": uuid.uuid4().hex,
            "text": doc.page_content,
            "dense_vector": dense.tolist(),
            "sparse_vector": sparse,  # dict {token_id: weight}
            "metadata": doc.metadata or {},
        }
        for doc, dense, sparse in zip(docs, dense_vecs, sparse_vecs)
    ]

    client.insert(collection_name=name, data=rows)


def hybrid_search(
    uri: str,
    name: str,
    dense_q,
    sparse_q,
    k: int = 20,
) -> List[Tuple[Document, float]]:
    """
    Run a dense + sparse hybrid search fused with Milvus' RRF ranker.

    Args:
        uri: Milvus URI, used only if the "default" connection is not open yet
        name: Collection name
        dense_q: Dense query vector
        sparse_q: Sparse query vector
        k: Number of candidates per sub-search and of fused results

    Returns:
        List of (Document, score) pairs; an empty list if anything fails
        (the error is printed, not raised).
    """
    try:
        # The Collection API needs a registered connection; open it only once
        # instead of reconnecting on every query.
        connection_alias = "default"
        if not connections.has_connection(connection_alias):
            connections.connect(alias=connection_alias, uri=uri)

        coll = Collection(name)

        # Dense search params depend on which index type was built.
        _, dense_search_params = get_index_config()

        dense_req = AnnSearchRequest(
            data=[dense_q],
            anns_field="dense_vector",
            param=dense_search_params,
            limit=k,
        )

        sparse_req = AnnSearchRequest(
            data=[sparse_q],
            anns_field="sparse_vector",
            param=SPARSE_SEARCH_PARAMS,
            limit=k,
        )

        ranker = RRFRanker(k=RRF_K)

        res = coll.hybrid_search(
            reqs=[dense_req, sparse_req],
            rerank=ranker,
            limit=k,
            output_fields=["text", "metadata"],
        )

        # Convert hits of the single query to (Document, score) pairs.
        out: List[Tuple[Document, float]] = []
        for hit in res[0]:
            meta = hit.entity.get("metadata", {})
            text = hit.entity.get("text", "")
            score = float(hit.distance)
            out.append((Document(page_content=text, metadata=meta), score))

        return out

    except Exception as e:
        print(f"❌ Error in hybrid search: {e}")
        return []


def search_dense(
    client: MilvusClient,
    collection_name: str,
    query_embedding: List[float],
    k: int = 20,
    search_params: Optional[Dict] = None,
) -> List[Dict]:
    """
    Perform dense vector search.

    Args:
        client: Milvus client
        collection_name: Name of the collection
        query_embedding: Dense query embedding
        k: Number of results to return
        search_params: Search parameters; when omitted, the parameters that
            match the dense index of the current connection are used

    Returns:
        List of dicts with ``id``, ``dense_score``, 1-based ``rank`` and
        ``entity``; an empty list if the search fails (the error is printed).
    """
    try:
        if not search_params:
            # Use the params that belong to the index actually built
            # (HNSW vs. IVF_FLAT), the same way hybrid_search does.
            _, search_params = get_index_config()

        results = client.search(
            collection_name=collection_name,
            data=[query_embedding],
            anns_field="dense_vector",
            search_params=search_params,
            limit=k,
            output_fields=["text", "metadata"],
        )

        if not results:
            return []

        # One query was sent, so only the first result list is relevant.
        return [
            {
                "id": hit.id,
                "dense_score": float(hit.distance),
                "rank": rank,
                "entity": hit.entity,
            }
            for rank, hit in enumerate(results[0], 1)
        ]

    except Exception as e:
        print(f"❌ Error in dense search: {e}")
        return []


def search_sparse(
    client: MilvusClient,
    collection_name: str,
    query_sparse: Dict,
    k: int = 20,
) -> List[Dict]:
    """
    Perform sparse vector search.

    Args:
        client: Milvus client
        collection_name: Name of the collection
        query_sparse: Sparse query embedding (dict of token_id -> weight)
        k: Number of results to return

    Returns:
        List of dicts with ``id``, ``sparse_score``, 1-based ``rank`` and
        ``entity``; an empty list if the search fails (the error is printed).
    """
    try:
        results = client.search(
            collection_name=collection_name,
            data=[query_sparse],
            anns_field="sparse_vector",
            search_params=SPARSE_SEARCH_PARAMS,
            limit=k,
            output_fields=["text", "metadata"],
        )

        if not results:
            return []

        # One query was sent, so only the first result list is relevant.
        return [
            {
                "id": hit.id,
                "sparse_score": float(hit.distance),
                "rank": rank,
                "entity": hit.entity,
            }
            for rank, hit in enumerate(results[0], 1)
        ]

    except Exception as e:
        print(f"❌ Error in sparse search: {e}")
        return []


def reciprocal_rank_fusion(
    dense_results: List[Dict],
    sparse_results: List[Dict],
    k: int = 10,
    dense_weight: float = 0.7,
    sparse_weight: float = 0.3,
) -> List[Dict]:
    """
    Fuse dense and sparse result lists with weighted Reciprocal Rank Fusion.

    Each list contributes ``weight / (k + rank)`` per document, where rank is
    the 1-based position in that list. The summed scores are then rescaled to
    the range [0.1, 1.0] (0.5 for every document when all scores are equal).
    The input lists and their dicts are not modified.

    Args:
        dense_results: Dense hits; each dict needs an ``id`` (entries whose id
            is missing/None are skipped) and may carry ``dense_score``
        sparse_results: Sparse hits, same shape with ``sparse_score``
        k: RRF smoothing constant (not a result count)
        dense_weight: Weight of the dense ranking
        sparse_weight: Weight of the sparse ranking

    Returns:
        Copies of the hit dicts sorted by fused score (descending), each with
        ``rrf_score`` and ``combined_score`` added; an empty list on error.
    """
    try:
        doc_scores = {}
        all_docs = {}

        # Process dense results
        for rank, result in enumerate(dense_results, 1):
            doc_id = result.get("id")
            if doc_id is None:
                continue
            doc_scores[doc_id] = doc_scores.get(doc_id, 0) + dense_weight / (k + rank)
            # Work on a copy so the caller's dicts stay untouched.
            merged = dict(result)
            merged["dense_score"] = result.get("dense_score", 0)
            all_docs[doc_id] = merged

        # Process sparse results
        for rank, result in enumerate(sparse_results, 1):
            doc_id = result.get("id")
            if doc_id is None:
                continue
            doc_scores[doc_id] = doc_scores.get(doc_id, 0) + sparse_weight / (k + rank)

            sparse_score = result.get("sparse_score", 0)
            if doc_id in all_docs:
                all_docs[doc_id]["sparse_score"] = sparse_score
            else:
                merged = dict(result)
                merged["sparse_score"] = sparse_score
                merged["dense_score"] = 0
                all_docs[doc_id] = merged

        # Rescale the fused scores
        if doc_scores:
            max_score = max(doc_scores.values())
            min_score = min(doc_scores.values())
            score_range = max_score - min_score

            for doc_id, raw in doc_scores.items():
                if score_range > 0:
                    doc_scores[doc_id] = 0.1 + 0.9 * ((raw - min_score) / score_range)
                else:
                    doc_scores[doc_id] = 0.5

        # Sort and format results
        fused_results = []
        for doc_id, rrf_score in sorted(
            doc_scores.items(), key=lambda item: item[1], reverse=True
        ):
            doc = all_docs[doc_id]
            doc["rrf_score"] = rrf_score
            doc["combined_score"] = rrf_score
            fused_results.append(doc)

        return fused_results

    except Exception as e:
        print(f"❌ Error in RRF fusion: {e}")
        return []

### 2.3 Embedding Model and Builder Vector Store

In [None]:
class BGEM3Encoder:
    """Encode text with BGE-M3 producing dense and sparse vectors."""

    def __init__(
        self,
        model: str = "BAAI/bge-m3",
        device: str = "cpu",
        normalize_embeddings: bool = True,
        use_fp16: bool = False,  # Set to False to avoid dtype mismatches
        max_length: int = 512,
        batch_size: int = 32,
        trust_remote_code: bool = True,
    ) -> None:
        """
        Load the BGE-M3 model.

        Args:
            model: Model id or path handed to BGEM3FlagModel
            device: Device string forwarded to the model
            normalize_embeddings: Whether the model should L2-normalize the
                dense vectors
            use_fp16: Run the model in half precision
            max_length: Maximum number of tokens per text; longer texts are
                truncated by the model
            batch_size: Default batch size for ``encode``
            trust_remote_code: Forwarded to the model loader
        """
        self.model_id = model
        self.device = device
        self.normalize_embeddings = normalize_embeddings
        self.use_fp16 = use_fp16
        self.max_length = max_length
        self.batch_size = batch_size
        self.model = BGEM3FlagModel(
            model,
            # Forward the setting; otherwise the model silently keeps its own
            # default regardless of what the caller asked for.
            normalize_embeddings=normalize_embeddings,
            device=device,
            use_fp16=use_fp16,
            trust_remote_code=trust_remote_code,
        )

    def encode(
        self, text_or_texts: Union[str, List[str]], batch_size: Optional[int] = None
    ) -> Dict[str, List]:
        """
        Encode text(s) to dense and sparse vectors.

        Args:
            text_or_texts: Single text string or list of texts
            batch_size: Batch size for encoding (defaults to the instance's)

        Returns:
            Dictionary with 'dense_vecs' (float32 array, one row per text) and
            'lexical_weights' (one dict per text). On any error the problem is
            printed and an empty structure is returned instead of raising.
        """
        # Imported locally, as in the original implementation.
        import numpy as np

        try:
            texts = [text_or_texts] if isinstance(text_or_texts, str) else text_or_texts

            if not texts:
                raise ValueError("No texts provided for encoding")

            out = self.model.encode(
                sentences=texts,
                batch_size=batch_size or self.batch_size,
                max_length=self.max_length,
                return_dense=True,
                return_sparse=True,
                return_colbert_vecs=False,
            )

            # Validate output structure
            if "dense_vecs" not in out or "lexical_weights" not in out:
                raise ValueError("Invalid output from BGE-M3 model")

            # Ensure a consistent dtype regardless of fp16/fp32 inference.
            out["dense_vecs"] = np.asarray(out["dense_vecs"], dtype=np.float32)
            return out

        except Exception as e:
            print(f"❌ Error encoding with BGE-M3: {e}")
            # Return empty structure on error
            return {
                "dense_vecs": np.array([], dtype=np.float32),
                "lexical_weights": [],
            }

    def encode_query(self, text: str) -> Dict[str, List]:
        """Encode a single query text."""
        return self.encode(text)


def load_encoder() -> BGEM3Encoder:
    """Load the BGE-M3 encoder configured from the notebook's global settings.

    The encoder's ``max_length`` is set to ``MAX_CHUNK_TOKENS`` so that chunks
    produced by the chunker are embedded in full; with the encoder's default
    of 512 tokens, longer chunks would be truncated before encoding.

    Returns:
        A CPU ``BGEM3Encoder`` for ``EMBED_MODEL_ID``.
    """
    return BGEM3Encoder(
        model=EMBED_MODEL_ID,
        device="cpu",
        normalize_embeddings=ENCODE_KWARGS.get("normalize_embeddings", True),
        max_length=MAX_CHUNK_TOKENS,
        batch_size=ENCODE_KWARGS.get("batch_size", 32),
    )


def build_hybrid_vectorstore(documents: List[Document], batch_size: int = 32) -> bool:
    """Encode all documents and store them in a freshly created hybrid collection.

    Steps: obtain the shared Milvus client, (re)create the collection, encode
    and insert the documents batch by batch, then build the indexes.

    Args:
        documents: Chunks to index (their ``page_content`` is embedded)
        batch_size: Number of documents encoded and inserted per batch

    Returns:
        True on success; False when ``documents`` is empty or a batch fails
        (in which case the indexes are not built).
    """
    if not documents:
        print("❌ No documents to process")
        return False

    doc_count = len(documents)
    print(f"🔄 Building hybrid vectorstore with {doc_count} documents...")

    # Shared client; the flag tells which dense index type will be built.
    client, supports_hnsw = get_milvus_client()
    ensure_hybrid_collection(client, COLLECTION_NAME, EMBEDDING_DIM)
    encoder = load_encoder()

    # Ceiling division: the last batch may be smaller than batch_size.
    total_batches = (doc_count + batch_size - 1) // batch_size

    for batch_no, offset in enumerate(range(0, doc_count, batch_size), start=1):
        chunk = documents[offset : offset + batch_size]
        contents = [item.page_content for item in chunk]

        print(f"📦 Processing batch {batch_no}/{total_batches} ({len(chunk)} docs)")

        try:
            encoded = encoder.encode(contents, batch_size=len(contents))
            dense_part = encoded["dense_vecs"]
            sparse_part = encoded["lexical_weights"]

            insert_documents(client, COLLECTION_NAME, dense_part, sparse_part, chunk)

            # Release the batch before encoding the next one.
            del encoded, dense_part, sparse_part, chunk, contents
            gc.collect()

        except Exception as e:
            print(f"❌ Error processing batch {batch_no}: {e}")
            return False

    print("🔧 Building indexes...")
    build_indexes(client, COLLECTION_NAME)

    index_type = "HNSW" if supports_hnsw else "IVF_FLAT"
    print(f"✅ Hybrid vectorstore built successfully with {index_type} indexing!")
    return True

In [None]:
# Build the Milvus collection from the chunks produced in section 1.
# `success` is False when there are no documents or a batch fails.
success = build_hybrid_vectorstore(all_documents, batch_size=32)

___
## 3. Retrieve from RAG

In [None]:
from abc import ABC, abstractmethod
import os
from dotenv import load_dotenv
from google.genai import Client, types
from FlagEmbedding import FlagReranker

### 3.1 LLM Definition and Call Managements

In [None]:
class BaseLLM(ABC):
    """Common interface shared by every LLM backend.

    Subclasses must implement ``generate``; the constructor only records the
    generic generation settings.
    """

    def __init__(
        self,
        model_id: str,
        api_key: Optional[str] = None,
        max_tokens: int = MAX_OUTPUT_TOKENS,
        temperature: float = TEMPERATURE,
        **kwargs,
    ):
        """
        Store the settings common to all backends.

        Args:
            model_id: Model identifier
            api_key: API key for the service
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Accepted for signature compatibility; not used here
        """
        self.model_id, self.api_key = model_id, api_key
        self.max_tokens, self.temperature = max_tokens, temperature

    @abstractmethod
    def generate(self, prompt: str) -> str:
        """
        Produce the model's answer for ``prompt``.

        Args:
            prompt: Input prompt

        Returns:
            Generated text response
        """


class GeminiLLM(BaseLLM):
    """Google Gemini LLM implementation"""

    def __init__(
        self,
        model_id: str = "gemini-2.5-flash",
        api_key: Optional[str] = None,
        max_tokens: int = MAX_OUTPUT_TOKENS,
        temperature: float = TEMPERATURE,
        **kwargs,
    ):
        """
        Initialize Gemini LLM.

        Args:
            model_id: Gemini model identifier
            api_key: API key (read from the GEMINI_API_KEY environment
                variable, after loading a .env file, when None)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Passed through to BaseLLM, which ignores them

        Raises:
            ValueError: If no API key is given and GEMINI_API_KEY is not set
        """
        load_dotenv()
        resolved_key = api_key or os.getenv("GEMINI_API_KEY")

        if not resolved_key:
            raise ValueError("GEMINI_API_KEY is required")

        # Share the common attributes (model_id, api_key, max_tokens,
        # temperature) with every other BaseLLM implementation.
        super().__init__(
            model_id=model_id,
            api_key=resolved_key,
            max_tokens=max_tokens,
            temperature=temperature,
            **kwargs,
        )

        self.llm = Client(api_key=self.api_key)
        self.config = types.GenerateContentConfig(
            system_instruction="You are a helpful assistant.",
            max_output_tokens=max_tokens,
            temperature=temperature,
        )

    def generate(self, prompt: str) -> str:
        """
        Generate response from prompt.

        Args:
            prompt: Input prompt

        Returns:
            Generated text, stripped of surrounding whitespace. Failures do
            not raise: a string starting with "Error generating response:" is
            returned instead.
        """
        try:
            response = self.llm.models.generate_content(
                model=self.model_id,
                contents=prompt,
                config=self.config,
            )
            text = response.text
            if text is None:
                # No text part in the response; report it explicitly instead
                # of failing on None.strip().
                return (
                    "Error generating response: model returned no text "
                    "(the response may have been blocked or truncated)"
                )
            return text.strip()

        except Exception as e:
            return f"Error generating response: {e}"


class LLMFactory:
    """Factory for creating LLM model instances"""

    # Register available models here: model type -> implementing class.
    MODELS = {
        "gemini": GeminiLLM,
    }

    # Model id used for a model type when the caller does not pass one.
    DEFAULT_MODEL_IDS = {
        "gemini": "gemini-2.5-flash",
    }

    @classmethod
    def create_llm(
        cls, model_type: str = "gemini", model_id: Optional[str] = None, **kwargs
    ) -> BaseLLM:
        """
        Create an LLM instance.

        Args:
            model_type: Key of the model in ``MODELS`` (e.g. "gemini")
            model_id: Specific model ID; when None, the entry in
                ``DEFAULT_MODEL_IDS`` is used, or "default-model" if the type
                has no registered default
            **kwargs: Additional parameters passed to the model class

        Returns:
            Initialized LLM instance

        Raises:
            ValueError: If model_type is not supported
        """
        if model_type not in cls.MODELS:
            available = ", ".join(cls.MODELS.keys())
            raise ValueError(
                f"Model '{model_type}' not supported. Available: {available}"
            )

        model_class = cls.MODELS[model_type]

        if model_id is None:
            model_id = cls.DEFAULT_MODEL_IDS.get(model_type, "default-model")

        return model_class(model_id=model_id, **kwargs)

    @classmethod
    def list_available_models(cls) -> list:
        """Get list of available model types"""
        return list(cls.MODELS.keys())

In [None]:
class RAGLLMCaller:
    """
    All-in-one LLM caller: prepare prompt, init model, call LLM, return structured output.

    The concrete model is created through ``LLMFactory`` on the first call to
    ``generate_answer`` and reused for every later call.
    """

    # Context text used when there is nothing usable to show the model
    NO_CONTEXT_MESSAGE = "Không tìm thấy thông tin liên quan."

    def __init__(self, model_type: str = "gemini", **model_kwargs):
        """
        Initialize RAG LLM caller.

        Args:
            model_type: Type of model to use ("gemini", etc.)
            **model_kwargs: Additional parameters for the model constructor
        """
        self.model_type = model_type
        self.model_kwargs = model_kwargs
        self.llm: Optional[BaseLLM] = None

    def _init_llm(self, **kwargs):
        """Create the LLM on first use.

        Runtime ``kwargs`` override the constructor kwargs, but only for the
        call that actually creates the model; once ``self.llm`` is set this
        method does nothing.
        """
        if self.llm is None:
            combined_kwargs = {**self.model_kwargs, **kwargs}
            self.llm = LLMFactory.create_llm(self.model_type, **combined_kwargs)

    def _prepare_context(self, documents: List[Document]) -> str:
        """Format documents into a numbered context string.

        Blank documents are skipped; numbering follows the position in
        ``documents`` (1-based), so skipped entries leave gaps. When no
        document has any content the "nothing found" message is returned
        instead of an empty context.
        """
        if not documents:
            return self.NO_CONTEXT_MESSAGE

        context_parts = []
        for i, doc in enumerate(documents, 1):
            content = doc.page_content.strip()
            if content:
                context_parts.append(f"[Tài liệu {i}]:\n{content}")

        if not context_parts:
            # Every retrieved document was whitespace-only
            return self.NO_CONTEXT_MESSAGE

        return "\n\n".join(context_parts)

    def _prepare_prompt(self, query: str, context: str) -> str:
        """Fill the module-level PROMPT template with the context and query."""
        return PROMPT.format(context=context, query=query)

    def generate_answer(
        self,
        query: str,
        documents: List[Document],
        rerank_scores: Optional[List[float]] = None,
        **llm_kwargs,
    ) -> Dict[str, Any]:
        """
        Complete pipeline: documents → context → prompt → LLM → structured output

        Args:
            query: User question
            documents: Retrieved documents
            rerank_scores: Scores aligned with ``documents``. Accepted so that
                callers can pass them by keyword without the value being
                forwarded to the model constructor; not used in the prompt.
            **llm_kwargs: Optional LLM parameters (model_id, max_tokens,
                temperature). Only applied if the model has not been created yet.

        Returns:
            Dict with answer, context_length, model_type, success and error.
            Exceptions are caught and reported through ``success``/``error``.
        """
        try:
            # Step 1: Prepare context
            context = self._prepare_context(documents)

            # Step 2: Prepare prompt
            prompt = self._prepare_prompt(query, context)

            # Step 3: Initialize LLM (no-op after the first call)
            self._init_llm(**llm_kwargs)

            # Step 4: Call model
            answer = self.llm.generate(prompt)

            # Step 5: Return structured output
            return {
                "answer": answer,
                "context_length": len(context),
                "model_type": self.model_type,
                "success": True,
                "error": None,
            }

        except Exception as e:
            return {
                "answer": f"Xin lỗi, có lỗi xảy ra: {str(e)}",
                "context_length": 0,
                "model_type": self.model_type,
                "success": False,
                "error": str(e),
            }


# Module-wide cached caller; replaced when a different model or configuration
# is requested.
_rag_llm_caller = None


def get_rag_llm_caller(model_type: str = "gemini", **kwargs) -> RAGLLMCaller:
    """
    Get global RAG LLM caller instance.

    The cached instance is reused when it was built for the same
    ``model_type`` and the call either passes no ``kwargs`` or passes the same
    ``kwargs`` it was built with. Otherwise a new caller is created and
    becomes the cached one, so explicitly requested parameters are never
    silently dropped.

    Args:
        model_type: Type of model to use
        **kwargs: Model parameters

    Returns:
        RAGLLMCaller instance
    """
    global _rag_llm_caller
    cached = _rag_llm_caller
    must_rebuild = (
        cached is None
        or cached.model_type != model_type
        # A bare lookup (no kwargs) keeps whatever configuration is cached
        or (bool(kwargs) and getattr(cached, "model_kwargs", None) != kwargs)
    )
    if must_rebuild:
        _rag_llm_caller = RAGLLMCaller(model_type=model_type, **kwargs)
    return _rag_llm_caller

### 3.2 Load retriever and models

In [None]:
def load_retrieval_models() -> Tuple[Any, Any, Any]:
    """
    Load and initialize all models needed for RAG retrieval with PyMilvus.

    Returns:
        Tuple containing: Milvus client, embedding model, and reranker model

    Raises:
        RuntimeError: If the collection COLLECTION_NAME does not exist in the
            Milvus instance at MILVUS_URI.
    """
    # Load BGE-M3 embedding model (the helper collects garbage before loading)
    embedding_model = get_embedding_model()

    # Smoke-test the encoder with a minimal input, then drop the result
    # immediately to keep memory usage down.
    smoke_result = embedding_model.encode("test")
    del smoke_result
    gc.collect()

    # Initialize Milvus client
    client = MilvusClient(uri=str(MILVUS_URI))

    # Fail early with an actionable message when the vector store is missing
    if not client.has_collection(COLLECTION_NAME):
        raise RuntimeError(
            f"Milvus collection '{COLLECTION_NAME}' not found at '{MILVUS_URI}'. "
            "Build the vector store before loading the retrieval models."
        )

    # Load BGE reranker model
    reranker_model = get_reranker_model()

    # Final memory cleanup
    gc.collect()

    return client, embedding_model, reranker_model


def get_embedding_model():
    """Construct the BGE-M3 encoder (CPU, full precision) and return it."""
    # Reclaim memory before the model weights are allocated
    gc.collect()

    encoder_options = {
        "model": EMBED_MODEL_ID,
        "device": "cpu",  # CPU only, to avoid memory conflicts
        "normalize_embeddings": True,
        "use_fp16": False,  # FP16 disabled to avoid dtype issues
        "max_length": 512,
        "batch_size": 16,  # small batches keep memory usage low
    }
    return BGEM3Encoder(**encoder_options)


def get_reranker_model():
    """Construct the BGE reranker (CPU, full precision) and return it."""
    reranker_options = {
        "normalize": True,
        "use_fp16": False,  # FP16 disabled to avoid memory issues
        "device": "cpu",  # CPU only, to avoid CUDA OOM
    }
    return FlagReranker(RERANKER_MODEL_ID, **reranker_options)

In [None]:
class DocumentRetriever:
    """Retrieval component of the RAG system: hybrid search, then reranking.

    All configuration (client, models, ``k``, ``rerank_top_k``, threshold) is
    read from the parent RAG object at call time.
    """

    def __init__(self, parent_rag):
        """Initialize retriever with reference to parent RAG system."""
        self.parent = parent_rag

    def retrieve_and_rerank(self, query: str) -> Tuple[List[Document], List[float]]:
        """
        Retrieve and rerank documents for a query.

        Any failure is logged and reported as an empty result rather than
        raised, so callers can treat "nothing found" and "retrieval failed"
        the same way.

        Args:
            query: Search query

        Returns:
            Tuple of (documents, rerank_scores)
        """
        try:
            documents, scores = self._search(query)
            return self._rerank(query, documents, scores)
        except Exception:
            import logging

            logging.getLogger(__name__).exception(
                "Retrieval failed for query %r", query
            )
            return [], []

    def _search(self, query: str) -> Tuple[List[Document], List[float]]:
        """Embed the query, run the hybrid search and wrap hits as Documents."""
        embeddings = self.parent.embedding_model.encode([query])

        search_results = hybrid_search(
            client=self.parent.client,
            collection_name=self.parent.collection_name,
            query_embedding=embeddings["dense_vecs"][0],
            query_sparse=embeddings["lexical_weights"][0],
            k=self.parent.k,
            similarity_threshold=self.parent.similarity_threshold,
        )

        documents = []
        scores = []
        for result in search_results:
            # Text and metadata live under the hit's "entity" field
            entity = result.get("entity", {})
            documents.append(
                Document(
                    page_content=entity.get("text", ""),
                    metadata=entity.get("metadata", {}),
                )
            )
            # Prefer the fused RRF score, then the other scores a hit may carry
            scores.append(
                result.get(
                    "rrf_score",
                    result.get("combined_score", result.get("dense_score", 0.0)),
                )
            )
        return documents, scores

    def _rerank(
        self, query: str, documents: List[Document], scores: List[float]
    ) -> Tuple[List[Document], List[float]]:
        """Keep the best ``rerank_top_k`` documents.

        Uses the reranker model when one is configured; if it is missing or
        fails, falls back to the search order and search scores.
        """
        top_k = self.parent.rerank_top_k
        reranker = self.parent.reranker_model

        if documents and reranker:
            try:
                raw_scores = reranker.compute_score(
                    [[query, doc.page_content] for doc in documents]
                )
                rerank_scores = self._as_score_list(raw_scores)

                ranked = sorted(
                    zip(documents, rerank_scores),
                    key=lambda pair: pair[1],
                    reverse=True,
                )[:top_k]
                return [doc for doc, _ in ranked], [score for _, score in ranked]
            except Exception:
                import logging

                logging.getLogger(__name__).warning(
                    "Reranking failed; falling back to search scores",
                    exc_info=True,
                )

        # No reranking (or it failed): keep the leading search results
        return documents[:top_k], scores[:top_k]

    @staticmethod
    def _as_score_list(raw_scores) -> List[float]:
        """Normalize a reranker result to a plain list.

        Accepts a list/tuple, an object exposing ``tolist()`` (array-like), or
        a bare scalar such as the score for a single query/document pair.
        """
        if isinstance(raw_scores, (list, tuple)):
            return list(raw_scores)
        if hasattr(raw_scores, "tolist"):
            converted = raw_scores.tolist()
            return converted if isinstance(converted, list) else [converted]
        return [raw_scores]


class AnswerGenerator:
    """Answer-generation component of the RAG system.

    Delegates the actual LLM call to the parent's ``llm_caller`` and adds
    retrieval metadata to its result.
    """

    def __init__(self, parent_rag):
        """Initialize generator with reference to parent RAG system."""
        self.parent = parent_rag

    def generate_answer(
        self,
        query: str,
        documents: List[Document],
        rerank_scores: Optional[List[float]] = None,
        **llm_kwargs,
    ) -> Dict:
        """
        Generate answer from query and retrieved documents.

        Args:
            query: User question
            documents: Retrieved documents
            rerank_scores: Reranking scores aligned with ``documents``. Kept in
                the signature for callers; not forwarded to the LLM caller,
                whose extra keyword arguments are treated as model parameters.
            **llm_kwargs: Additional LLM parameters

        Returns:
            The LLM caller's result dict with ``retrieval_count`` added. The
            ``success`` flag reported by the caller is preserved; it is only
            set to True when the caller did not report one. With no documents,
            or on an unexpected exception, a failure dict with ``answer``,
            ``sources``, ``confidence``, ``success`` and ``retrieval_count``
            (plus ``error`` on exception) is returned.
        """
        if not documents:
            return {
                "answer": "Tôi không tìm thấy thông tin phù hợp để trả lời câu hỏi của bạn.",
                "sources": [],
                "confidence": 0.0,
                "success": False,
                "retrieval_count": 0,
            }

        try:
            # Generate answer using LLM
            result = self.parent.llm_caller.generate_answer(
                query=query,
                documents=documents,
                **llm_kwargs,
            )

            # Add retrieval metadata without masking a failure reported by
            # the LLM caller.
            result["retrieval_count"] = len(documents)
            result.setdefault("success", True)

            return result

        except Exception as e:
            return {
                "answer": "Xin lỗi, đã có lỗi xảy ra khi xử lý câu hỏi của bạn.",
                "sources": [],
                "confidence": 0.0,
                "success": False,
                "error": str(e),
                "retrieval_count": len(documents),
            }

    def switch_model(self, model_type: str, **model_kwargs):
        """
        Switch LLM model.

        The parent's caller and model-type tracking are only updated if the
        new caller is obtained successfully; errors propagate to the caller.

        Args:
            model_type: New model type
            **model_kwargs: Model parameters
        """
        self.parent.llm_caller = get_rag_llm_caller(
            model_type=model_type, **model_kwargs
        )

        # Update parent's model type tracking
        self.parent._current_model_type = model_type


class VietnameseRAG:
    """
    Unified Vietnamese RAG system with modular subclasses.

    Holds the shared configuration and models, and delegates the work to two
    components: ``retriever`` (DocumentRetriever) and ``generator``
    (AnswerGenerator). Both read their settings from this object.
    """

    def __init__(
        self,
        # Connection parameters
        client=None,
        collection_name: str = COLLECTION_NAME,
        # Model parameters
        embedding_model=None,
        reranker_model=None,
        # Retrieval parameters
        k: int = DEFAULT_K,
        rerank_top_k: int = RERANK_TOP_K,
        similarity_threshold: float = SIMILARITY_THRESHOLD,
        # LLM parameters
        model_type: str = "gemini",
        llm_caller: Optional[RAGLLMCaller] = None,
        **llm_kwargs,
    ):
        """
        Initialize unified Vietnamese RAG system.

        Anything not supplied is created with the module's helper functions.

        Args:
            client: Milvus client
            collection_name: Collection name in Milvus
            embedding_model: BGE-M3 embedding model
            reranker_model: Reranker model
            k: Initial retrieval count
            rerank_top_k: Final count after reranking
            similarity_threshold: Minimum similarity threshold
            model_type: LLM model type (gemini, watsonx, etc.)
            llm_caller: Custom LLM caller instance
            **llm_kwargs: Additional LLM parameters
        """
        # Store configuration
        self.collection_name = collection_name
        self.k = k
        self.rerank_top_k = rerank_top_k
        self.similarity_threshold = similarity_threshold
        self._current_model_type = model_type

        # Initialize Milvus client
        self.client = self._unwrap_client(client or get_milvus_client())

        # Load models
        self.embedding_model = embedding_model or get_embedding_model()
        self.reranker_model = reranker_model or get_reranker_model()

        # Initialize LLM caller
        self.llm_caller = llm_caller or get_rag_llm_caller(
            model_type=model_type, **llm_kwargs
        )

        # Initialize modular subclasses
        self.retriever = DocumentRetriever(self)
        self.generator = AnswerGenerator(self)

    @staticmethod
    def _unwrap_client(candidate):
        """Return the client itself when given a ``(client, ...)`` tuple.

        ``get_milvus_client()`` is unpacked elsewhere in this file as
        ``client, supports_hnsw``; storing that tuple as the client would
        break every search call, so only its first element is kept.
        """
        if isinstance(candidate, tuple):
            return candidate[0]
        return candidate

    def answer(self, query: str, **llm_kwargs) -> Dict:
        """
        Main method: Complete RAG flow from query to answer.

        Args:
            query: User question
            **llm_kwargs: Additional LLM parameters

        Returns:
            The generator's result dict, or a failure dict (``success`` False)
            when nothing was retrieved or an exception occurred.
        """
        try:
            # Step 1: Retrieve and rerank documents
            documents, rerank_scores = self.retriever.retrieve_and_rerank(query)

            if not documents:
                return {
                    "answer": "Tôi không tìm thấy thông tin phù hợp để trả lời câu hỏi của bạn.",
                    "sources": [],
                    "confidence": 0.0,
                    "success": False,
                    "retrieval_count": 0,
                }

            # Step 2: Generate answer
            return self.generator.generate_answer(
                query=query,
                documents=documents,
                rerank_scores=rerank_scores,
                **llm_kwargs,
            )

        except Exception as e:
            return {
                "answer": "Xin lỗi, đã có lỗi xảy ra khi xử lý câu hỏi của bạn.",
                "sources": [],
                "confidence": 0.0,
                "success": False,
                "error": str(e),
                "retrieval_count": 0,
            }

    def switch_model(self, model_type: str, **model_kwargs):
        """
        Switch LLM model.

        Args:
            model_type: New model type (gemini, watsonx, etc.)
            **model_kwargs: Model-specific parameters
        """
        self.generator.switch_model(model_type, **model_kwargs)

    @property
    def model_type(self) -> str:
        """Get current LLM model type."""
        return self._current_model_type

    @property
    def status(self) -> Dict[str, Any]:
        """Get system status and configuration.

        The "loaded"/"connected" flags only report that the corresponding
        attribute is set; no connectivity or health check is performed.
        """
        return {
            "model_type": self.model_type,
            "collection_name": self.collection_name,
            "retrieval_config": {
                "k": self.k,
                "rerank_top_k": self.rerank_top_k,
                "similarity_threshold": self.similarity_threshold,
            },
            "models_loaded": {
                "embedding": self.embedding_model is not None,
                "reranker": self.reranker_model is not None,
                "llm": self.llm_caller is not None,
            },
            "milvus_connected": self.client is not None,
        }

    def update_config(
        self,
        k: Optional[int] = None,
        rerank_top_k: Optional[int] = None,
        similarity_threshold: Optional[float] = None,
    ):
        """
        Update retrieval configuration.

        Arguments left as None keep their current value.

        Args:
            k: New initial retrieval count
            rerank_top_k: New reranking count
            similarity_threshold: New similarity threshold
        """
        if k is not None:
            self.k = k

        if rerank_top_k is not None:
            self.rerank_top_k = rerank_top_k

        if similarity_threshold is not None:
            self.similarity_threshold = similarity_threshold


# Convenience function for backward compatibility
def get_vietnamese_rag(**kwargs) -> VietnameseRAG:
    """
    Build a VietnameseRAG, forwarding every keyword argument to its constructor.

    Args:
        **kwargs: Configuration parameters accepted by VietnameseRAG

    Returns:
        VietnameseRAG instance
    """
    rag_system = VietnameseRAG(**kwargs)
    return rag_system

### 3.3 Define RAG Chain and Run

In [None]:
def perform_retrieval(rag, query: str) -> Tuple[List[Document], List[float]]:
    """
    Perform document retrieval and reranking only.

    Failures are logged and reported as an empty result instead of being
    raised.

    Args:
        rag: The Vietnamese RAG system instance
        query: The search query

    Returns:
        Tuple of (documents, rerank_scores); two empty lists on failure
    """
    try:
        # Use the retriever directly for cleaner separation
        documents, scores = rag.retriever.retrieve_and_rerank(query)
    except Exception:
        import logging

        logging.getLogger(__name__).exception("Retrieval failed for query %r", query)
        return [], []

    return documents, scores


def generate_answer(rag, query: str, documents: List[Document]) -> Dict[str, Any]:
    """
    Generate an answer using LLM based on retrieved documents.

    Args:
        rag: The Vietnamese RAG system instance
        query: The search query
        documents: Retrieved documents

    Returns:
        Dictionary with ``success``, ``answer`` and ``error``; ``confidence``
        is included whenever the generator was actually called. ``success``
        and ``error`` mirror what the generator reported.
    """
    if not documents:
        return {
            "success": False,
            "error": "No documents provided for answer generation",
            "answer": None,
        }

    try:
        # VietnameseRAG exposes its answer generator as `generator`; the older
        # `answer_generator` name is still honoured for other rag objects.
        generator = getattr(rag, "generator", None) or rag.answer_generator
        answer_result = generator.generate_answer(query, documents)

        return {
            "success": bool(answer_result.get("success", True)),
            "answer": answer_result.get("answer", ""),
            "confidence": answer_result.get("confidence", 0.0),
            "error": answer_result.get("error"),
        }

    except Exception as e:
        return {"success": False, "error": str(e), "answer": None}


def display_results(
    query: str,
    documents: List[Document],
    scores: List[float],
    answer_result: Dict[str, Any],
) -> None:
    """
    Print the outcome of one retrieval + answer-generation round.

    Prints a "no documents" notice when nothing was retrieved, the query and
    answer when generation succeeded, and otherwise a failure notice with the
    reported error (if any).

    Args:
        query: The search query
        documents: Retrieved documents
        scores: Document scores (not displayed at present)
        answer_result: LLM answer generation result
    """
    if not documents:
        print("❌ No relevant documents found.")
        return

    # If answer generation succeeded, show only query and answer
    if answer_result.get("success") and answer_result.get("answer"):
        print(f"\n🔍 **Query:** {query}")
        print("\n🤖 **Answer:**")
        print("=" * 50)
        print(answer_result["answer"])
        return

    # Answer generation failed or produced nothing: report what is known
    print("\n⚠️ **LLM calling failed or returned no answer**")
    if answer_result.get("error"):
        print(f"Error: {answer_result['error']}")

    print(f"\n🔍 Query: {query}")
    print(
        f"📊 Retrieved {len(documents)} documents but LLM answer generation failed."
    )

In [None]:
# Milvus connection used by the RAG system
client, supports_hnsw = get_milvus_client()

# load_retrieval_models() returns (client, embedding_model, reranker_model).
# The connection created above is the one handed to the RAG system, so the
# loader's own client is only kept under a throwaway name.
_loader_client, embedding_model, reranker_model = load_retrieval_models()

rag = VietnameseRAG(
    client=client,
    collection_name=COLLECTION_NAME,
    embedding_model=embedding_model,
    reranker_model=reranker_model,
)

In [None]:
# Interactive question/answer loop: read a query, retrieve, generate, display.
print("\n💡 Enter 'quit' to exit the system")
print("🤖 Complete RAG: BGE-M3 search → LLM answer generation")
print("=" * 50)

while True:
    try:
        # Get user input
        query = input("\n🔍 Enter your query: ").strip()

        # Check for quit command
        if query.lower() in ("quit", "q", "exit"):
            print("👋 Goodbye!")
            break

        if not query:
            print("⚠️ Please enter a valid query")
            continue

        # Step 1: Perform retrieval
        documents, scores = perform_retrieval(rag, query)

        # Step 2: Generate answer (returns a failure result if no documents)
        answer_result = generate_answer(rag, query, documents)

        # Step 3: Display results
        display_results(query, documents, scores, answer_result)

        # Clean up after each query
        gc.collect()

    except (KeyboardInterrupt, EOFError):
        # EOFError: input stream closed. Without this the generic handler
        # below would catch it on every iteration and the loop would never end.
        print("\n\n👋 Goodbye!")
        break
    except Exception as e:
        print(f"❌ Error processing query: {e}")
        continue