In [None]:
# Import necessary libraries
import os
import time
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List, Dict, Any, Optional, Union
import logging
from pymilvus import (
    connections, 
    utility,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    Function,
    FunctionType,
    AnnSearchRequest,
    RRFRanker,
    WeightedRanker
)
from langchain_ollama import OllamaEmbeddings

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Connect to Milvus
def connect_to_milvus(host="localhost", port="19530", user="", password=""):
    """Open the "default" pymilvus connection.

    Args:
        host: Milvus server hostname.
        port: Milvus server port.
        user: Username for authentication ("" for none).
        password: Password for authentication ("" for none).

    Returns:
        True if ``connections.connect`` succeeded; False if it raised (the error is
        logged rather than propagated).
    """
    connection_kwargs = dict(
        alias="default",
        host=host,
        port=port,
        user=user,
        password=password,
    )
    try:
        connections.connect(**connection_kwargs)
    except Exception as exc:
        logger.error(f"Error connecting to Milvus: {str(exc)}")
        return False
    logger.info(f"Connected to Milvus server at {host}:{port}")
    return True

# Open the default Milvus connection (explicit local-server defaults); the cell displays True/False
connect_to_milvus(host="localhost", port="19530")

In [None]:
# Collection that stores dense embeddings alongside BM25 sparse vectors
collection_name: str = "document_store_hybrid"
# Size of the dense FLOAT_VECTOR field; must match the dimension of the embeddings inserted
embedding_dim: int = 3072

def create_hybrid_collection(drop_existing=True, shards_num=2):
    """Create the collection used for hybrid (dense vector + BM25) search.

    Schema highlights:
      * ``content`` is an analyzed VARCHAR (``enable_analyzer`` / ``enable_match``): it is
        the input of the BM25 function and the target of TEXT_MATCH filters.
      * ``embedding`` holds dense vectors of size ``embedding_dim`` (HNSW index, COSINE).
      * ``sparse`` is produced by the ``content_bm25_emb`` BM25 function from ``content``,
        so inserts do not supply it; it is indexed with metric BM25.

    Args:
        drop_existing: If True (the default, original behavior), an existing collection
            named ``collection_name`` is dropped and rebuilt -- this deletes its data.
            If False, an existing collection is loaded and returned unchanged.
        shards_num: Shard count used when a new collection is created.

    Returns:
        The loaded ``Collection``.
    """
    if utility.has_collection(collection_name):
        if not drop_existing:
            existing = Collection(name=collection_name)
            existing.load()
            logger.info(f"Reusing existing collection {collection_name}")
            return existing
        utility.drop_collection(collection_name)
        logger.info(f"Dropped existing collection {collection_name}")

    fields = [
        FieldSchema(name="id", dtype=DataType.VARCHAR, max_length=100, is_primary=True),
        FieldSchema(name="document_id", dtype=DataType.VARCHAR, max_length=100),
        FieldSchema(name="case_id", dtype=DataType.VARCHAR, max_length=256),
        FieldSchema(name="chunk_id", dtype=DataType.VARCHAR, max_length=100),
        # Analyzed text: input of the BM25 function and target of TEXT_MATCH filters
        FieldSchema(
            name="content",
            dtype=DataType.VARCHAR,
            max_length=65535,
            enable_analyzer=True,  # This enables text analysis for BM25
            enable_match=True,  # Allows TEXT_MATCH(content, ...) in filters
        ),
        FieldSchema(name="content_type", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="chunk_type", dtype=DataType.VARCHAR, max_length=50),
        FieldSchema(name="page_number", dtype=DataType.INT64),
        FieldSchema(name="tree_level", dtype=DataType.INT64),
        FieldSchema(name="metadata", dtype=DataType.JSON),
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=embedding_dim),
        # Sparse vectors written by the BM25 function below
        FieldSchema(name="sparse", dtype=DataType.SPARSE_FLOAT_VECTOR),
    ]

    schema = CollectionSchema(fields=fields, description="Document store with hybrid search")

    # BM25 function: converts `content` into the `sparse` field at insert time
    bm25_function = Function(
        name="content_bm25_emb",
        input_field_names=["content"],
        output_field_names=["sparse"],
        function_type=FunctionType.BM25,
    )
    schema.add_function(bm25_function)

    collection = Collection(name=collection_name, schema=schema, shards_num=shards_num)
    logger.info(f"Created collection {collection_name}")

    dense_index_params = {
        "index_type": "HNSW",
        "metric_type": "COSINE",
        "params": {"M": 16, "efConstruction": 128},
    }
    collection.create_index(field_name="embedding", index_params=dense_index_params)
    logger.info("Created vector index on embedding field")

    # Only index settings belong in index_params. The target field and the index name
    # are separate create_index arguments; placing them inside index_params does not
    # set the index name.
    sparse_index_params = {
        "index_type": "SPARSE_INVERTED_INDEX",  # or "AUTOINDEX"
        "metric_type": "BM25",  # required for full-text
        "params": {"inverted_index_algo": "DAAT_MAXSCORE"},  # optional
    }
    collection.create_index(
        field_name="sparse",
        index_params=sparse_index_params,
        index_name="sparse_inverted_index",
    )
    logger.info("Created sparse (BM25) index on field 'sparse'")

    collection.load()
    logger.info(f"Collection {collection_name} loaded")

    return collection

# Create collection
collection = create_hybrid_collection()

In [None]:
def import_data_from_existing_collection(source_collection_name="document_store", limit=1000):
    """
    Copy chunks from an existing collection into the hybrid collection.

    Rows are read from ``source_collection_name`` with a single ``query`` capped at
    ``limit`` and inserted into ``collection_name``. The ``sparse`` field is not
    supplied: the hybrid collection's BM25 function derives it from ``content``.

    Args:
        source_collection_name: Collection to copy from.
        limit: Maximum number of rows copied by this call. A warning is logged when
            exactly ``limit`` rows come back, since the source may hold more.

    Returns:
        The number of entities inserted (an int). Returns 0 when the source is empty
        or when any step raises; errors are logged, not propagated.
    """
    try:
        source_collection = Collection(name=source_collection_name)
        source_collection.load()
        logger.info(f"Loaded source collection {source_collection_name}")

        results = source_collection.query(
            expr="",  # no filter; `limit` bounds the result size
            output_fields=[
                "id", "document_id", "case_id", "chunk_id", "content",
                "content_type", "chunk_type", "page_number", "tree_level",
                "metadata", "embedding"
            ],
            limit=limit
        )

        if not results:
            logger.warning("No data found in source collection")
            return 0

        logger.info(f"Retrieved {len(results)} documents from source collection")
        if limit is not None and len(results) >= limit:
            logger.warning(
                f"Retrieved limit={limit} rows; {source_collection_name} may contain more "
                "rows that were not imported"
            )

        # Missing optional fields fall back to the defaults used for new chunks.
        entities = [
            {
                "id": entity.get("id"),
                "document_id": entity.get("document_id"),
                "case_id": entity.get("case_id"),
                "chunk_id": entity.get("chunk_id"),
                "content": entity.get("content"),
                "content_type": entity.get("content_type", "text"),
                "chunk_type": entity.get("chunk_type", "original"),
                "page_number": entity.get("page_number", -1),
                "tree_level": entity.get("tree_level", 0),
                "metadata": entity.get("metadata", {}),
                "embedding": entity.get("embedding"),
            }
            for entity in results
        ]

        target_collection = Collection(name=collection_name)
        target_collection.insert(entities)
        target_collection.flush()
        logger.info(f"Imported {len(entities)} entities from {source_collection_name} to {collection_name}")

        return len(entities)

    except Exception as e:
        logger.error(f"Error importing data: {str(e)}")
        return 0

# Copy up to 1000 chunks from the "document_store" collection into the hybrid collection
num_imported = import_data_from_existing_collection(source_collection_name="document_store", limit=1000)
print(f"Imported {num_imported} documents")

In [None]:
def insert_sample_data(num_samples=10):
    """
    Insert chunks from a small built-in sample corpus into ``collection_name``.

    Intended as a fallback when nothing was imported from an existing collection.
    Every sample row gets case_id "case_test"; rows are grouped three per document
    ("doc_1", "doc_2", ...).

    Args:
        num_samples: How many of the built-in texts to insert. There are 10 texts, so
            values above 10 insert all of them; values <= 0 insert nothing.

    Returns:
        The number of entities inserted.
    """
    import uuid

    sample_texts = [
        "The transformer architecture revolutionized natural language processing with its attention mechanism.",
        "Document retrieval systems use vector similarity to find relevant information.",
        "BM25 is a ranking function used in information retrieval to estimate document relevance.",
        "Hybrid search combines the benefits of semantic search and keyword-based retrieval.",
        "Large language models can generate coherent text based on prompts.",
        "Vector databases store and query high-dimensional embeddings efficiently.",
        "Knowledge graphs represent information as interconnected entities and relationships.",
        "Text chunking is crucial for effective document indexing and retrieval.",
        "Semantic search understands the meaning behind queries rather than just keywords.",
        "Retrieval augmented generation combines search with text generation for better results."
    ]
    # max(..., 0) keeps a negative count from slicing items off the end of the list.
    texts = sample_texts[:max(num_samples, 0)]
    if not texts:
        logger.warning("insert_sample_data called with num_samples <= 0; nothing inserted")
        return 0

    # OllamaEmbeddings is imported in the notebook's first cell.
    embedding_model = OllamaEmbeddings(model="llama3.2")
    embeddings = embedding_model.embed_documents(texts)

    entities = []
    for i, (text, embedding) in enumerate(zip(texts, embeddings)):
        entities.append({
            "id": f"sample_{i+1}_{uuid.uuid4().hex[:8]}",
            "document_id": f"doc_{i//3 + 1}",  # 3 chunks per mock document
            "case_id": "case_test",  # single test case
            "chunk_id": f"chunk_{i+1}",
            "content": text,
            "content_type": "text",
            "chunk_type": "original",
            "page_number": i % 5,  # mock page numbers
            "tree_level": 0,  # all are original chunks
            "metadata": {"source": "sample"},
            "embedding": embedding
        })

    collection = Collection(name=collection_name)
    collection.insert(entities)
    collection.flush()

    logger.info(f"Inserted {len(entities)} sample entities")
    return len(entities)

# Seed the collection with the built-in sample chunks only if it is still empty
collection = Collection(name=collection_name)
if collection.num_entities != 0:
    print(f"Collection already has {collection.num_entities} entities")
else:
    num_samples = insert_sample_data(num_samples=10)
    print(f"Added {num_samples} sample documents")

In [None]:
# 1. Vector Search (Original approach)
def vector_search(query, case_id, document_ids=None, top_k=5):
    """
    Pure dense-vector search over the ``embedding`` field (COSINE metric, HNSW ``ef=64``).

    The query is embedded with ``OllamaEmbeddings(model="llama3.2")``, and hits are
    filtered to ``case_id`` and, optionally, ``document_ids``. Filter values are
    emitted as escaped, double-quoted Milvus string literals, so IDs containing quotes
    or backslashes cannot break the expression.

    Args:
        query: Search query text.
        case_id: Case ID to filter by.
        document_ids: Optional list of document IDs to filter by.
        top_k: Number of results to return.

    Returns:
        (results, search_time): a list of hit dicts (``search_method`` == "vector") and
        the elapsed time in seconds.
    """
    start_time = time.time()

    def _quote(value):
        # Escape backslashes first, then double quotes, then wrap in double quotes.
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    embedding_model = OllamaEmbeddings(model="llama3.2")
    query_embedding = embedding_model.embed_query(query)

    search_params = {
        "metric_type": "COSINE",
        "params": {"ef": 64}
    }

    expr_parts = [f"case_id == {_quote(case_id)}"]
    if document_ids:
        if len(document_ids) == 1:
            expr_parts.append(f"document_id == {_quote(document_ids[0])}")
        else:
            doc_list = ", ".join(_quote(doc_id) for doc_id in document_ids)
            expr_parts.append(f"document_id in [{doc_list}]")
    expr = " && ".join(expr_parts)

    collection = Collection(name=collection_name)
    results = collection.search(
        data=[query_embedding],
        anns_field="embedding",
        param=search_params,
        limit=top_k,
        expr=expr,
        output_fields=["document_id", "chunk_id", "content", "content_type",
                       "chunk_type", "page_number", "tree_level", "metadata"]
    )

    formatted_results = [
        {
            "document_id": hit.entity.get("document_id"),
            "chunk_id": hit.entity.get("chunk_id"),
            "content": hit.entity.get("content"),
            "score": hit.score,
            "content_type": hit.entity.get("content_type"),
            "chunk_type": hit.entity.get("chunk_type"),
            "page_number": hit.entity.get("page_number"),
            "tree_level": hit.entity.get("tree_level"),
            "metadata": hit.entity.get("metadata"),
            "search_method": "vector"
        }
        for hits in results
        for hit in hits
    ]

    search_time = time.time() - start_time
    logger.info(f"Vector search completed in {search_time:.3f}s")

    return formatted_results, search_time

# 2. BM25 Search
def bm25_search(query, case_id, document_ids=None, top_k=5):
    """
    BM25 full-text search against the BM25-populated ``sparse`` field.

    The raw query text is passed as search data on ``sparse`` (filled by the
    collection's BM25 function from ``content``). A TEXT_MATCH keyword filter on
    ``content`` is ANDed with the case/document filters.

    Every string value is emitted as an escaped, double-quoted Milvus literal. Queries
    containing apostrophes (e.g. "claimant's") and multi-document filters therefore
    produce a valid expression.

    Args:
        query: Search query text.
        case_id: Case ID to filter by.
        document_ids: Optional list of document IDs to filter by.
        top_k: Number of results to return.

    Returns:
        (results, search_time): a list of hit dicts (``search_method`` == "bm25") and
        the elapsed time in seconds.
    """
    start_time = time.time()

    def _quote(value):
        # Escape backslashes first, then double quotes, then wrap in double quotes.
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    expr_parts = [f"case_id == {_quote(case_id)}"]
    if document_ids:
        if len(document_ids) == 1:
            expr_parts.append(f"document_id == {_quote(document_ids[0])}")
        else:
            doc_list = ", ".join(_quote(doc_id) for doc_id in document_ids)
            expr_parts.append(f"document_id in [{doc_list}]")
    # TEXT_MATCH is a keyword *filter* on the analyzed content field; the BM25 ranking
    # itself comes from searching the sparse field below.
    expr_parts.append(f"TEXT_MATCH(content, {_quote(query)})")
    expr = " && ".join(expr_parts)

    # NOTE(review): Milvus documents BM25 k1/b as index-build parameters
    # (bm25_k1 / bm25_b on the sparse index); these search-time values are likely
    # ignored -- confirm against the Milvus version in use.
    search_params = {
        "metric_type": "BM25",
        "params": {"k1": 1.5, "b": 0.75}
    }

    collection = Collection(name=collection_name)
    results = collection.search(
        data=[query],  # raw text; the BM25 function handles tokenization
        anns_field="sparse",
        param=search_params,
        limit=top_k,
        expr=expr,
        output_fields=["document_id", "chunk_id", "content", "content_type",
                       "chunk_type", "page_number", "tree_level", "metadata"]
    )

    formatted_results = [
        {
            "document_id": hit.entity.get("document_id"),
            "chunk_id": hit.entity.get("chunk_id"),
            "content": hit.entity.get("content"),
            "score": hit.score,
            "content_type": hit.entity.get("content_type"),
            "page_number": hit.entity.get("page_number"),
            "search_method": "bm25"
        }
        for hits in results
        for hit in hits
    ]

    search_time = time.time() - start_time
    logger.info(f"BM25 search completed in {search_time:.3f}s")

    return formatted_results, search_time

# 3. Hybrid Search (Vector + BM25)
def hybrid_search(query, case_id, document_ids=None, top_k=5, vector_weight=0.5):
    """
    Client-side hybrid search that fuses ``vector_search`` and ``bm25_search`` results.

    Steps:
      1) call ``vector_search`` and ``bm25_search`` with the same filters and ``top_k``;
      2) min-max normalize each method's scores to [0, 1] independently;
      3) fuse per (document_id, chunk_id) as
         ``vector_weight * vector_n + (1 - vector_weight) * bm25_n``, using 0.0 for a
         method that did not return that chunk;
      4) return the ``top_k`` highest fused scores.

    If all hits of one method have the same score (including the single-hit case),
    they normalize to 1.0 rather than 0.0, so a lone match still carries that method's
    full weight.

    Args:
        query: Search query text.
        case_id: Case ID to filter by.
        document_ids: Optional list of document IDs to filter by.
        top_k: Number of fused results to return (also used per underlying method).
        vector_weight: Weight of the vector score; BM25 gets ``1 - vector_weight``.

    Returns:
        (results, elapsed): each result has document_id, chunk_id, content, vector_n,
        bm25_n, score and search_method "hybrid_minmax_weighted"; elapsed is in seconds.
    """
    start_time = time.time()

    vec_hits, _ = vector_search(query, case_id, document_ids, top_k)
    bm_hits, _ = bm25_search(query, case_id, document_ids, top_k)

    def _score_range(hits):
        scores = [h["score"] for h in hits]
        return (min(scores), max(scores)) if scores else (0.0, 0.0)

    def _minmax(s, lo, hi):
        # Degenerate range (one hit, or all tied): every hit counts as the best hit.
        return 1.0 if hi == lo else (s - lo) / (hi - lo)

    # Merge both result lists keyed by (doc, chunk); chunk_id alone may repeat across docs.
    merged = {}
    for hits, slot in ((vec_hits, "vector_n"), (bm_hits, "bm25_n")):
        lo, hi = _score_range(hits)
        for h in hits:
            key = (h["document_id"], h["chunk_id"])
            entry = merged.setdefault(key, {
                "document_id": h["document_id"],
                "chunk_id": h["chunk_id"],
                "content": h["content"],
                "vector_n": 0.0,
                "bm25_n": 0.0
            })
            entry[slot] = _minmax(h["score"], lo, hi)

    fused = [
        {
            **entry,
            "score": vector_weight * entry["vector_n"] + (1 - vector_weight) * entry["bm25_n"],
            "search_method": "hybrid_minmax_weighted"
        }
        for entry in merged.values()
    ]

    fused.sort(key=lambda x: x["score"], reverse=True)
    top = fused[:top_k]

    elapsed = time.time() - start_time
    logger.info(f"Hybrid‐minmax‐weighted completed in {elapsed:.3f}s")
    return top, elapsed



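# 3b. Alternative hybrid search (currently disabled): server-side fusion via Collection.hybrid_search
# def hybrid_search(query, case_id, document_ids=None, top_k=5, rrf_k=60):
#     """
#     Hybrid search using Milvus's server-side hybrid_search, fusing vector and BM25 results
#     with a WeightedRanker (rrf_k is currently unused; it would apply if the ranker were switched to RRFRanker(rrf_k)).
#     """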
#     start_time = time.time()

#     # Generate embedding for query
#     embedding_model = OllamaEmbeddings(model="llama3.2")
#     query_embedding = embedding_model.embed_query(query)

#     # Build base filter expression
#     expr_parts = [f'case_id == "{case_id}"']
#     if document_ids:
#         if len(document_ids) == 1:
#             expr_parts.append(f'document_id == "{document_ids[0]}"')
#         else:
#             docs = '", "'.join(document_ids)
#             expr_parts.append(f'document_id in ["{docs}"]')
#     expr = " && ".join(expr_parts)

#     # Vector search request
#     vector_request = AnnSearchRequest(
#         data=[query_embedding],
#         anns_field="embedding",
#         param={"metric_type": "COSINE", "params": {"ef": 64}},
#         limit=top_k,
#         expr=expr
#     )

#     # BM25 search request (adds TEXT_MATCH)
#     bm25_request = AnnSearchRequest(
#         data=[query],
#         anns_field="sparse",
#         param={"metric_type": "BM25", "params": {"k1": 1.5, "b": 0.75}},
#         limit=top_k,
#         expr=expr + f" && TEXT_MATCH(content, '{query}')"
#     )

#     # 40% vector, 60% BM25 (WeightedRanker weights, in request order)
#     ranker = WeightedRanker(0.4, 0.6)

#     # Execute hybrid search
#     collection = Collection(name=collection_name)
#     results = collection.hybrid_search(
#         reqs=[vector_request, bm25_request],
#         rerank=ranker,
#         limit=top_k,
#         output_fields=[
#             "document_id","chunk_id","content",
#             "content_type","chunk_type","page_number",
#             "tree_level","metadata"
#         ]
#     )

#     # Format results
#     formatted = []
#     for hits in results:
#         for hit in hits:
#             formatted.append({
#                 "document_id": hit.entity.get("document_id"),
#                 "chunk_id": hit.entity.get("chunk_id"),
#                 "content": hit.entity.get("content"),
#                 "score": hit.score,
#                 "content_type": hit.entity.get("content_type"),
#                 "page_number": hit.entity.get("page_number"),
#                 "search_method": "hybrid_rrf"
#             })

#     logger.info(f"Hybrid RRF search completed in {time.time() - start_time:.3f}s")
#     return formatted, time.time() - start_time



In [None]:
def _run_search(label, search_fn, query, case_id, document_ids, top_k):
    """Call one search method; if it raises, log the error and return ([], 0)."""
    try:
        return search_fn(query, case_id, document_ids, top_k)
    except Exception as e:
        logger.error(f"{label} search error: {str(e)}")
        return [], 0


def _results_frame(results, score_col):
    """Project hits onto (content, <score_col>, document_id, search_method), one row per distinct content."""
    columns = ['content', score_col, 'document_id', 'search_method']
    if not results:
        return pd.DataFrame(columns=columns)
    frame = pd.DataFrame(results).rename(columns={'score': score_col})[columns]
    # The comparison merges on 'content'; identical text in several chunks would make
    # those merges many-to-many and duplicate rows. Keep the best-scoring hit per content.
    return (
        frame.sort_values(score_col, ascending=False, kind='stable')
        .drop_duplicates(subset='content', keep='first')
        .reset_index(drop=True)
    )


def compare_search_methods(query, case_id, document_ids=None, top_k=5):
    """
    Run vector, BM25 and hybrid search for one query and compare their results.

    Prints timings, a per-content comparison table and an overlap breakdown, then shows
    a bar chart of result counts. A method that raises is logged and treated as
    returning no results in 0 seconds.

    Args:
        query: Search query.
        case_id: Case ID.
        document_ids: Optional list of document IDs.
        top_k: Number of results to request from each method.

    Returns:
        Dict with the raw result lists ("vector", "bm25", "hybrid"), the comparison
        DataFrame ("comparison", one row per distinct content), and "timings".
    """
    print(f"Query: '{query}'")
    print(f"Case ID: {case_id}")
    print(f"Document IDs: {document_ids or 'All'}")
    print(f"Top K: {top_k}")
    print("-" * 80)

    vector_results, vector_time = _run_search("Vector", vector_search, query, case_id, document_ids, top_k)
    bm25_results, bm25_time = _run_search("BM25", bm25_search, query, case_id, document_ids, top_k)
    hybrid_results, hybrid_time = _run_search("Hybrid", hybrid_search, query, case_id, document_ids, top_k)

    print("\nSearch Performance:")
    print(f"Vector Search: {vector_time:.3f}s")
    print(f"BM25 Search: {bm25_time:.3f}s")
    print(f"Hybrid Search: {hybrid_time:.3f}s")

    vector_df = _results_frame(vector_results, 'vector_score')
    bm25_df = _results_frame(bm25_results, 'bm25_score')
    hybrid_df = _results_frame(hybrid_results, 'hybrid_score')

    all_contents = set(vector_df['content']) | set(bm25_df['content']) | set(hybrid_df['content'])

    # One row per distinct content, with each method's score (NaN/None if not found).
    combined_df = pd.DataFrame({'content': list(all_contents)})
    for frame, score_col in ((vector_df, 'vector_score'), (bm25_df, 'bm25_score'), (hybrid_df, 'hybrid_score')):
        if frame.empty:
            combined_df[score_col] = None
        else:
            combined_df = combined_df.merge(frame[['content', score_col]], on='content', how='left')

    for method in ('vector', 'bm25', 'hybrid'):
        combined_df[f'found_by_{method}'] = combined_df[f'{method}_score'].notna()
    combined_df['methods_count'] = combined_df[['found_by_vector', 'found_by_bm25', 'found_by_hybrid']].sum(axis=1)

    # Sort by hybrid, then vector, then BM25 score -- skipping columns with no scores.
    sort_cols = [col for col in ('hybrid_score', 'vector_score', 'bm25_score') if combined_df[col].notna().any()]
    if sort_cols:
        combined_df = combined_df.sort_values(by=sort_cols, ascending=False)

    print("\nResults Comparison:")
    with pd.option_context('display.max_colwidth', 80):
        print(combined_df.head(top_k).to_string())

    print("\nOverlap Analysis:")
    vector_set = set(vector_df['content'])
    bm25_set = set(bm25_df['content'])
    hybrid_set = set(hybrid_df['content'])

    print(f"Total unique contents: {len(all_contents)}")
    print(f"Vector-only results: {len(vector_set - bm25_set - hybrid_set)}")
    print(f"BM25-only results: {len(bm25_set - vector_set - hybrid_set)}")
    print(f"Hybrid-only results: {len(hybrid_set - vector_set - bm25_set)}")
    print(f"Vector & BM25 overlap: {len((vector_set & bm25_set) - hybrid_set)}")
    print(f"Vector & Hybrid overlap: {len((vector_set & hybrid_set) - bm25_set)}")
    print(f"BM25 & Hybrid overlap: {len((bm25_set & hybrid_set) - vector_set)}")
    print(f"All methods overlap: {len(vector_set & bm25_set & hybrid_set)}")

    method_counts = {
        'Vector': len(vector_results),
        'BM25': len(bm25_results),
        'Hybrid': len(hybrid_results)
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(list(method_counts.keys()), list(method_counts.values()))
    ax.set_xlabel('Search Method')
    ax.set_ylabel('Number of Results')
    ax.set_title(f'Number of Results by Search Method - Query: "{query}"')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop

    return {
        "vector": vector_results,
        "bm25": bm25_results,
        "hybrid": hybrid_results,
        "comparison": combined_df,
        "timings": {
            "vector": vector_time,
            "bm25": bm25_time,
            "hybrid": hybrid_time
        }
    }

# Test with different queries
def run_test_queries(case_id="default", document_ids=None, queries=None, top_k=5):
    """
    Run ``compare_search_methods`` for each query and collect the results.

    Args:
        case_id: Case ID passed to every search. Note that ``insert_sample_data`` tags
            its rows with case_id "case_test", so pass that value when the collection
            holds only the sample corpus.
        document_ids: Optional list of document IDs to restrict every search to.
        queries: Queries to run; defaults to ["claimant", "nanghloi depot"].
        top_k: Number of results requested from each method per query.

    Returns:
        Dict mapping each query string to its ``compare_search_methods`` result.
    """
    if queries is None:
        queries = ["claimant", "nanghloi depot"]

    results = {}
    for query in queries:
        print(f"\n{'='*50}\nTesting query: '{query}'\n{'='*50}")
        results[query] = compare_search_methods(query, case_id, document_ids, top_k=top_k)

    return results

# Compare all three search methods on the default test queries (case_id "default", all documents)
test_results = run_test_queries(case_id="default", document_ids=None)