# Ablation Study: Semantic vs. Hybrid Embedding Methods



## MuRP and MuRE Nearest-Neighbor Retrieval

In [None]:
import numpy as np
from web.embeddings import load_embedding
from web.evaluate import poincare_distance  # Hyperbolic distance (MuRP)
from scipy.spatial.distance import euclidean  # Euclidean distance (MuRE)

# ----------------------
# 1. Load Embeddings (using MuRP as example)
# ----------------------
# model_type selects the distance used at retrieval time:
# "poincare" for MuRP, "euclidean" for MuRE.
model_type = "poincare"
embedding_path = "./Mathlib4_embeddings/outputs_cleaned/embeddings/model_dict_poincare_300_my_dataset_cleaned"

# load_embedding returns an Embedding instance. normalize=True is kept from
# the original setup, which notes it is required for hyperbolic embeddings
# (points projected inside the unit ball).
embeddings = load_embedding(
    embedding_path,
    clean_words=False,
    lower=True,
    normalize=True,
    format="dict",
)

# Vocabulary, plus one vector per vocabulary entry (same order as `vocab`)
vocab = embeddings.words
vectors = np.array([embeddings[token] for token in vocab])

# ----------------------
# 2. Define Retrieval Function (with distance normalization)
# ----------------------
def retrieve_nearest_neighbors(query_word, top_k=5):
    """Return the ``top_k`` vocabulary words closest to ``query_word``.

    Distances are computed with ``poincare_distance`` when the module-level
    ``model_type`` is ``"poincare"`` and with ``euclidean`` otherwise. The
    query word itself is never returned.

    Args:
        query_word: Word to look up in the module-level ``embeddings``.
        top_k: Maximum number of neighbors to return. Values <= 0 yield an
            empty list.

    Returns:
        A list of ``(word, distance, normalized_distance)`` tuples ordered by
        increasing distance. ``normalized_distance`` is the min-max scaling of
        the distances over the returned neighbors only. The list is empty when
        there is no other word in the vocabulary or ``top_k`` <= 0.

    Raises:
        ValueError: If ``query_word`` is not in the embedding vocabulary.
    """
    if query_word not in embeddings:
        raise ValueError(f"Query word '{query_word}' not found in the embedding vocabulary.")

    query_vec = embeddings[query_word]
    # Pick the metric once instead of re-testing model_type for every word.
    distance_fn = poincare_distance if model_type == "poincare" else euclidean

    # Skip the query word up front; the sort is stable, so ties keep
    # vocabulary order exactly as before.
    distances = [
        (word, distance_fn(query_vec, embeddings[word]))
        for word in vocab
        if word != query_word
    ]
    distances.sort(key=lambda item: item[1])

    # A negative top_k would otherwise slice from the wrong end of the list.
    top_neighbors = distances[:max(top_k, 0)]

    # min()/max() raise on an empty array, so return early when nothing is left.
    if not top_neighbors:
        return []

    # Normalize distances using min-max scaling
    dists = np.array([dist for _, dist in top_neighbors])
    min_dist, max_dist = dists.min(), dists.max()
    norm_dists = (dists - min_dist) / (max_dist - min_dist + 1e-9)  # +1e-9 to avoid division by zero

    return [(word, dist, norm) for (word, dist), norm in zip(top_neighbors, norm_dists)]

# ----------------------
# 3. Retrieval Example
# ----------------------
if __name__ == "__main__":
    # Demo parameters; the query must exist in the loaded vocabulary.
    top_k = 5
    query = "transpose"  # Replace with a valid word from your vocabulary

    try:
        neighbors = retrieve_nearest_neighbors(query, top_k=top_k)
        print(f"Top {top_k} most similar words to '{query}' in {model_type} space:")
        # Ranks are reported 1-based.
        for i, (word, dist, norm) in enumerate(neighbors, start=1):
            print(f"{i}. {word} (distance: {dist:.4f}, normalized: {norm:.4f})")
    except ValueError as e:
        # Raised by the retrieval function when the query word is unknown.
        print(e)


This notebook implements an ablation study to evaluate the effectiveness of different embedding strategies
for Mathlib4 concept retrieval. Specifically, it compares:

1. **Only Semantic Text Embeddings** — using Sentence Transformer embeddings for retrieval.
2. **Hybrid Method** — combining semantic embeddings with MuRP-based hyperbolic embeddings for enhanced relational reasoning.

The study measures retrieval performance, coverage of relevant concepts, and quality of extracted modules,
providing insights into the contribution of each component in the hybrid approach.

In [None]:
import os
import json
import numpy as np
import re
import random
import time
from collections import defaultdict
from sentence_transformers import SentenceTransformer
from web.embeddings import load_embedding
from web.evaluate import poincare_distance
from scipy.spatial.distance import euclidean

# Avoid parallelism warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Pattern for a valid mathematical concept name: lowercase letters, digits and
# underscores only, at least two characters long.
valid_definiendum_pattern = re.compile(r"^[a-z0-9_]{2,}$")  # Updated regex to match lowercase only
# Characters stripped from names before any other filtering.
math_symbols = set("+-*/=∑∏√∫<>∈∉{}[]()")

def clean_definiendum(definiendum):
    """Normalize a definiendum to a lowercase ``[a-z0-9_]`` identifier.

    Dots and spaces become underscores; mathematical symbols, non-ASCII
    characters and any other character outside letters, digits and
    underscores are dropped. Returns an empty string for non-string or empty
    input, or when nothing is left after cleaning.
    """
    if not isinstance(definiendum, str) or not definiendum:
        return ""

    # Dots are turned into underscores before any filtering happens.
    text = definiendum.replace(".", "_")

    # Drop mathematical symbols and non-ASCII characters in a single pass.
    text = "".join(ch for ch in text if ch not in math_symbols and ord(ch) < 128)

    # Spaces become underscores; every other remaining non-identifier
    # character is removed.
    text = re.sub(r'[^A-Za-z0-9_]', '', text.replace(" ", "_"))

    # An empty result here signals an invalid name.
    return text.lower()

def normalize_entity_name(name):
    """Lowercase ``name`` and join its whitespace-separated parts with underscores."""
    # split() without arguments already discards leading/trailing whitespace
    # and collapses runs of whitespace between tokens.
    tokens = name.lower().split()
    return "_".join(tokens)

class SemanticSearch:
    """Pure semantic search model (for ablation experiment comparison).

    Loads precomputed concept embeddings from a JSON dataset, encodes queries
    with a Sentence Transformer model and ranks the concepts by cosine
    similarity to the query.
    """

    def __init__(self, json_path, model_name="sentence-transformers/sentence-t5-large", verbose=True):
        """Load the dataset and the query encoder.

        Args:
            json_path: Path of the JSON dataset containing concepts and their
                ``embedding_vector`` entries.
            model_name: Sentence Transformer model name used to encode queries.
            verbose: Whether informational messages are printed.

        Raises:
            FileNotFoundError: If ``json_path`` does not exist.
            ValueError: If no concept in the dataset carries an embedding.
        """
        self.json_path = json_path
        self.model_name = model_name
        self.verbose = verbose
        self.embedding_model = None
        self.embeddings = None
        self.flattened_concepts = []
        self.original_dataset = None

        self.load_data()
        self.load_model()

    def log(self, message, level="info"):
        """Print ``message`` when verbose is on; errors are always printed."""
        if self.verbose or level == "error":
            print(message)

    def load_data(self):
        """Load the dataset and flatten every concept that has an embedding.

        After this call, row ``i`` of ``self.embeddings`` belongs to
        ``self.flattened_concepts[i]``. Each stored concept is a copy with
        ``name`` and ``module_path`` filled in when missing.

        Raises:
            FileNotFoundError: If ``self.json_path`` does not exist.
            ValueError: If no concept carries an ``embedding_vector``.
        """
        self.log("Loading semantic search data...")

        if not os.path.exists(self.json_path):
            raise FileNotFoundError(f"File not found: {self.json_path}")

        with open(self.json_path, "r", encoding="utf-8") as f:
            self.original_dataset = json.load(f)

        # Build both lists locally and assign them together, so calling
        # load_data() again replaces the concepts instead of appending
        # duplicates to the previous list.
        embeddings_list = []
        flattened_concepts = []

        for module_name, module_content in self.original_dataset.items():
            for definition in module_content.get("definitions", []):
                semantic_analysis = definition.get("semantic_analysis", {})

                for concept in semantic_analysis.get("concepts", []):
                    vec = concept.get("embedding_vector")
                    if vec is None:
                        continue

                    embeddings_list.append(vec)
                    concept_copy = concept.copy()
                    if "name" not in concept_copy:
                        concept_copy["name"] = concept_copy.get("concept_name", "Unknown")
                    if "module_path" not in concept_copy:
                        concept_copy["module_path"] = module_name

                    flattened_concepts.append(concept_copy)

        if not embeddings_list:
            raise ValueError("No valid semantic embeddings found")

        self.flattened_concepts = flattened_concepts
        self.embeddings = np.array(embeddings_list)
        self.log(f"Loaded {len(self.embeddings)} semantic concept embeddings")

    def load_model(self):
        """Load the Sentence Transformer model, with a fallback model on failure."""
        self.log(f"Loading semantic model: {self.model_name}")
        try:
            self.embedding_model = SentenceTransformer(self.model_name)
            self.log("Semantic model loaded successfully")
        except Exception as e:
            self.log(f"Failed to load semantic model: {e}", "error")
            self.log("Falling back to default model...")
            # NOTE(review): the fallback encoder may not share an embedding
            # space with the stored embedding_vector values — confirm that
            # similarity scores remain meaningful when this path is taken.
            self.embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    def search_similar_concepts(self, query, top_k=10, min_score=0.1):
        """Rank concepts by cosine similarity to ``query``.

        Args:
            query: Text to encode and search for.
            top_k: Maximum number of unique concepts to return.
            min_score: Concepts scoring below this cosine similarity are skipped.

        Returns:
            A list of dicts with keys ``concept``, ``semantic_score`` and
            ``index``, ordered by descending similarity and de-duplicated on
            the cleaned concept name. Empty when no embeddings are loaded or
            when encoding/scoring fails (the failure is logged).
        """
        self.log(f"Semantic search for: '{query[:50]}...'")

        if self.embeddings is None or len(self.embeddings) == 0:
            self.log("No embeddings available", "error")
            return []

        try:
            # Generate query embedding
            query_vec = self.embedding_model.encode([query])
            if isinstance(query_vec, list):
                query_vec = np.array(query_vec)
            if query_vec.ndim > 1:
                query_vec = query_vec[0]

            # Normalize vectors (the epsilon guards against zero-norm rows)
            query_vec = query_vec / (np.linalg.norm(query_vec) + 1e-8)
            norm_embeddings = self.embeddings / (np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-8)

            # Calculate cosine similarity
            cosine_scores = np.dot(norm_embeddings, query_vec)

            # Indices of all concepts, sorted by similarity descending
            all_indices = np.argsort(cosine_scores)[::-1]

        except Exception as e:
            self.log(f"Semantic search failed: {e}", "error")
            return []

        # Collect results (with deduplication)
        results = []
        seen_concepts = set()  # Cleaned concept names already returned

        for i in all_indices:
            if len(results) >= top_k:
                break

            if i < len(self.flattened_concepts):
                concept = self.flattened_concepts[i]
                score = cosine_scores[i]

                # Skip scores below threshold
                if score < min_score:
                    continue

                # Concepts whose name cleans to an empty string are dropped.
                cleaned_name = clean_definiendum(concept.get('name', ''))

                if cleaned_name and cleaned_name not in seen_concepts:
                    results.append({
                        'concept': concept,
                        'semantic_score': float(score),
                        'index': int(i)
                    })
                    seen_concepts.add(cleaned_name)

        self.log(f"   {len(results)} unique results above threshold {min_score}")
        return results

    def semantic_search(self, query, top_k=10, min_score=0.1, return_top_k=None):
        """Execute semantic search and return the related modules.

        Args:
            query: Text to search for.
            top_k: Maximum number of concepts retrieved.
            min_score: Minimum cosine similarity for a concept to count.
            return_top_k: When given, keep only this many modules, chosen by
                relevance (best-ranked concept first).

        Returns:
            The selected module paths as an alphabetically sorted list.
        """
        results = self.search_similar_concepts(query, top_k, min_score)

        # Results arrive ordered by descending similarity, and dict keys keep
        # first-seen order, so this de-duplicates while preserving ranking.
        ranked_modules = list(dict.fromkeys(
            module
            for module in (result['concept'].get('module_path', '') for result in results)
            if module
        ))

        # Truncate by relevance BEFORE sorting: sorting first would keep the
        # alphabetically-first modules instead of the most relevant ones.
        if return_top_k is not None:
            ranked_modules = ranked_modules[:return_top_k]

        return sorted(ranked_modules)

class HybridSemanticSearch:
    """Hybrid search model (semantic search + MuRP).

    Phase 1 ranks concepts by cosine similarity between the query and the
    precomputed concept embeddings. Phase 2 looks the retrieved concept names
    up in the MuRP/MuRE embedding and collects their nearest neighbors.
    Modules are then scored from both phases.
    """

    def __init__(self, json_path, murp_embedding_path,
                 semantic_model="sentence-transformers/sentence-t5-large",
                 murp_model_type="poincare",
                 verbose=True):
        """
        Hybrid retrieval system: semantic search first, then MuRP for secondary retrieval

        Args:
            json_path: JSON file path (contains concepts and embeddings)
            murp_embedding_path: MuRP embedding file path
            semantic_model: Sentence Transformer model name
            murp_model_type: MuRP model type ("poincare" or "euclidean")
            verbose: Whether to show detailed logs
        """
        self.json_path = json_path
        self.murp_embedding_path = murp_embedding_path
        self.semantic_model_name = semantic_model
        self.murp_model_type = murp_model_type
        self.verbose = verbose

        # Initialize components
        self.semantic_model = None
        self.murp_embeddings = None
        self.semantic_embeddings = None
        self.flattened_concepts = []
        self.murp_vocab = []
        self.original_dataset = None  # Store original dataset
        self.evaluation_data = None   # Store evaluation data
        # Lazily built map: cleaned definition name -> module (see
        # _get_definition_module_index).
        self._definition_module_index = None

        # Load data
        self.load_semantic_data()
        self.load_semantic_model()
        self.load_murp_embeddings()

    def log(self, message, level="info"):
        """Print ``message`` when verbose is on; errors are always printed."""
        if self.verbose or level == "error":
            print(message)

    def load_semantic_data(self):
        """Load the dataset and flatten every concept that has an embedding.

        After this call, row ``i`` of ``self.semantic_embeddings`` belongs to
        ``self.flattened_concepts[i]``.

        Raises:
            FileNotFoundError: If ``self.json_path`` does not exist.
            ValueError: If no concept carries an ``embedding_vector``.
        """
        self.log("Loading semantic search data...")

        if not os.path.exists(self.json_path):
            raise FileNotFoundError(f"File not found: {self.json_path}")

        with open(self.json_path, "r", encoding="utf-8") as f:
            self.original_dataset = json.load(f)

        # The dataset changed, so any previously built name index is stale.
        self._definition_module_index = None

        # Build both lists locally so a repeated call replaces the concepts
        # instead of appending duplicates.
        embeddings_list = []
        flattened_concepts = []

        for module_name, module_content in self.original_dataset.items():
            for definition in module_content.get("definitions", []):
                semantic_analysis = definition.get("semantic_analysis", {})

                for concept in semantic_analysis.get("concepts", []):
                    vec = concept.get("embedding_vector")
                    if vec is None:
                        continue

                    embeddings_list.append(vec)
                    concept_copy = concept.copy()
                    if "name" not in concept_copy:
                        concept_copy["name"] = concept_copy.get("concept_name", "Unknown")
                    if "module_path" not in concept_copy:
                        concept_copy["module_path"] = module_name

                    flattened_concepts.append(concept_copy)

        if not embeddings_list:
            raise ValueError("No valid semantic embeddings found")

        self.flattened_concepts = flattened_concepts
        self.semantic_embeddings = np.array(embeddings_list)
        self.log(f"Loaded {len(self.semantic_embeddings)} semantic concept embeddings")

    def load_semantic_model(self):
        """Load the Sentence Transformer model; failures are logged and re-raised."""
        self.log(f"Loading semantic model: {self.semantic_model_name}")
        try:
            self.semantic_model = SentenceTransformer(self.semantic_model_name)
            self.log("Semantic model loaded successfully")
        except Exception as e:
            self.log(f"Failed to load semantic model: {e}", "error")
            raise

    def load_murp_embeddings(self):
        """Load the MuRP embeddings; failures are logged and re-raised."""
        self.log(f"Loading MuRP embeddings from: {self.murp_embedding_path}")
        self.log(f"   Model type: {self.murp_model_type}")

        try:
            self.murp_embeddings = load_embedding(
                self.murp_embedding_path,
                format="dict",
                normalize=True,   # Required for hyperbolic embeddings
                lower=True,        # Ensure loaded embeddings are lowercase
                clean_words=False
            )

            self.murp_vocab = self.murp_embeddings.words
            self.log(f"Loaded MuRP embeddings with {len(self.murp_vocab)} vocabulary items")

        except Exception as e:
            self.log(f"Failed to load MuRP embeddings: {e}", "error")
            raise

    def semantic_search(self, query, top_k=20, min_score=0.0):
        """Phase 1: Semantic search (with deduplication).

        Args:
            query: Text to encode and search for.
            top_k: Maximum number of unique concepts to return.
            min_score: Concepts scoring below this cosine similarity are skipped.

        Returns:
            A list of dicts with keys ``concept``, ``semantic_score`` and
            ``index``, ordered by descending similarity. Empty when no
            embeddings are loaded or when encoding/scoring fails.
        """
        self.log(f"Phase 1: Semantic search for: '{query[:50]}...'")

        if self.semantic_embeddings is None or len(self.semantic_embeddings) == 0:
            self.log("No semantic embeddings available", "error")
            return []

        try:
            # Generate query embedding
            query_vec = self.semantic_model.encode([query])
            if isinstance(query_vec, list):
                query_vec = np.array(query_vec)
            if query_vec.ndim > 1:
                query_vec = query_vec[0]

            # Normalize vectors (the epsilon guards against zero-norm rows)
            query_vec = query_vec / (np.linalg.norm(query_vec) + 1e-8)
            norm_embeddings = self.semantic_embeddings / (np.linalg.norm(self.semantic_embeddings, axis=1, keepdims=True) + 1e-8)

            # Calculate cosine similarity
            cosine_scores = np.dot(norm_embeddings, query_vec)

            # Indices of all concepts, sorted by similarity descending
            all_indices = np.argsort(cosine_scores)[::-1]

        except Exception as e:
            self.log(f"Semantic search failed: {e}", "error")
            return []

        # Collect results (with deduplication)
        results = []
        seen_concepts = set()  # Cleaned concept names already returned

        for i in all_indices:
            if len(results) >= top_k:
                break

            if i < len(self.flattened_concepts):
                concept = self.flattened_concepts[i]
                score = cosine_scores[i]

                # Skip scores below threshold
                if score < min_score:
                    continue

                # Concepts whose name cleans to an empty string are dropped.
                cleaned_name = clean_definiendum(concept.get('name', ''))

                if cleaned_name and cleaned_name not in seen_concepts:
                    results.append({
                        'concept': concept,
                        'semantic_score': float(score),
                        'index': int(i)
                    })
                    seen_concepts.add(cleaned_name)

        self.log(f"   {len(results)} unique results above threshold {min_score}")
        return results

    def murp_search(self, concept_names, top_k=10):
        """Phase 2: MuRP retrieval.

        Args:
            concept_names: Concept names to look up in the MuRP vocabulary.
            top_k: Number of neighbors retrieved per concept.

        Returns:
            A dict mapping each concept name found in the vocabulary to
            ``{'query_name_used': str, 'neighbors': [(word, dist, norm), ...]}``.
            Concepts that are not in the vocabulary under any tried spelling,
            or whose retrieval fails, are omitted.
        """
        self.log(f"Phase 2: MuRP search for {len(concept_names)} concepts")

        if not self.murp_embeddings:
            self.log("No MuRP embeddings available", "error")
            return {}

        murp_results = {}

        for concept_name in concept_names:
            # Since all names are lowercase, we only need to try possible variants
            possible_names = [
                concept_name,
                concept_name.replace('.', '_'),  # Ensure consistent dot handling
                concept_name.replace('_', '.'),
                concept_name.split('.')[-1] if '.' in concept_name else concept_name
            ]

            # First spelling present in the vocabulary wins.
            query_name = next(
                (name for name in possible_names if name in self.murp_embeddings),
                None,
            )
            if query_name is None:
                continue

            try:
                neighbors = self.retrieve_murp_neighbors(query_name, top_k)
                murp_results[concept_name] = {
                    'query_name_used': query_name,
                    'neighbors': neighbors
                }
            except Exception as e:
                self.log(f"MuRP search failed for {concept_name}: {e}", "error")

        return murp_results

    def retrieve_murp_neighbors(self, query_word, top_k=10):
        """MuRP neighbor retrieval.

        Args:
            query_word: Vocabulary word whose neighbors are wanted.
            top_k: Number of nearest neighbors to return.

        Returns:
            A list of ``(word, distance, normalized_distance)`` tuples ordered
            by increasing distance; the normalized value is the min-max
            scaling over the returned neighbors (all zeros when they are
            equidistant). Empty when there are no other words.

        Raises:
            ValueError: If ``query_word`` is not in the MuRP vocabulary.
        """
        if query_word not in self.murp_embeddings:
            raise ValueError(f"Query word '{query_word}' not found in MuRP vocabulary")

        query_vec = self.murp_embeddings[query_word]
        distances = []

        for word in self.murp_vocab:
            vec = self.murp_embeddings[word]
            if self.murp_model_type == "poincare":
                dist = poincare_distance(query_vec, vec)
            else:
                dist = euclidean(query_vec, vec)
            distances.append((word, dist))

        # Sort distances and exclude the query word itself
        distances_sorted = sorted(distances, key=lambda x: x[1])
        distances_sorted = [item for item in distances_sorted if item[0] != query_word]

        # Take top-k nearest neighbors
        top_neighbors = distances_sorted[:top_k]

        # Normalize distances
        if top_neighbors:
            dists = np.array([dist for _, dist in top_neighbors])
            min_dist, max_dist = dists.min(), dists.max()
            if max_dist > min_dist:
                norm_dists = (dists - min_dist) / (max_dist - min_dist)
            else:
                norm_dists = np.zeros_like(dists)

            return [(word, dist, norm) for (word, dist), norm in zip(top_neighbors, norm_dists)]

        return []

    def _get_definition_module_index(self):
        """Return the cached map from cleaned definition name to its module.

        For a name defined in several modules, the first module in dataset
        order is kept. Definitions without a name and modules with an empty
        name are ignored.
        """
        index = getattr(self, "_definition_module_index", None)
        if index is None:
            index = {}
            for module_name, module_data in self.original_dataset.items():
                if not module_name:
                    continue
                for definition in module_data.get("definitions", []):
                    def_name = definition.get("name", "")
                    if def_name:
                        # setdefault keeps the first module seen for a name.
                        index.setdefault(clean_definiendum(def_name), module_name)
            self._definition_module_index = index
        return index

    def hybrid_search(self, query, semantic_top_k=20, murp_top_k=10, semantic_min_score=0.1, return_top_k=10):
        """Hybrid retrieval: semantic search + MuRP secondary retrieval.

        Args:
            query: Text to search for.
            semantic_top_k: Number of concepts retrieved in phase 1.
            murp_top_k: Number of MuRP neighbors retrieved per concept.
            semantic_min_score: Minimum cosine similarity in phase 1.
            return_top_k: Number of modules returned.

        Returns:
            Module names ordered by descending accumulated score. A module
            gains the cosine score of each of its retrieved concepts, plus
            ``1 - normalized_distance`` for each MuRP neighbor defined in it.
            Empty when phase 1 yields nothing usable.
        """
        self.log(f"Starting hybrid search for: '{query[:50]}...'")

        # Phase 1: Semantic search (with deduplication)
        semantic_results = self.semantic_search(query, semantic_top_k, semantic_min_score)

        if not semantic_results:
            return []

        # Extract and clean concept names for MuRP retrieval; the
        # placeholder name "Unknown" is never sent to MuRP.
        raw_concept_names = [result['concept'].get('name', '') for result in semantic_results]
        cleaned_concept_names = [
            cleaned
            for cleaned in (
                clean_definiendum(name)
                for name in raw_concept_names
                if name and name != 'Unknown'
            )
            if cleaned
        ]

        if not cleaned_concept_names:
            return []

        # Phase 2: MuRP retrieval (using cleaned lowercase names)
        murp_results = self.murp_search(cleaned_concept_names, murp_top_k)

        # Modules and their accumulated relevance scores
        module_scores = defaultdict(float)

        # 1. Assign scores to modules from semantic search results
        for result in semantic_results:
            module = result['concept'].get('module_path', '')
            score = result.get('semantic_score', 0)
            if module:
                module_scores[module] += score  # Accumulate scores

        # 2. Assign scores to modules from MuRP results (based on distance).
        #    Neighbors are resolved through a name index built once, rather
        #    than rescanning the whole dataset for every neighbor.
        if murp_results:
            module_index = self._get_definition_module_index()
            for murp_data in murp_results.values():
                for neighbor, _, norm_dist in murp_data['neighbors']:
                    neighbor_module = module_index.get(neighbor)
                    if neighbor_module:
                        # Smaller distance means higher score (1 - normalized distance)
                        module_scores[neighbor_module] += (1 - norm_dist)

        # 3. Sort by score and select top return_top_k
        sorted_modules = sorted(module_scores.items(), key=lambda x: x[1], reverse=True)
        top_modules = [module for module, score in sorted_modules[:return_top_k]]

        return top_modules

class ModelEvaluator:
    """Model evaluator for comparing pure semantic search and hybrid search"""
    def __init__(self, json_path, murp_embedding_path=None):
        self.json_path = json_path
        self.murp_embedding_path = murp_embedding_path
        self.evaluation_data = None
        
        # Prepare evaluation data
        self.prepare_evaluation_data()
    
    def prepare_evaluation_data(self, sample_size=1000):
        """
        Prepare evaluation data: randomly sample definitions from JSON file
        Return format: [
            {
                "query": "informal description",
                "target_module": "definition module",
                "definition_name": "definition name"
            }
        ]
        """
        print("Preparing evaluation data...")
        
        # Load original dataset
        if not os.path.exists(self.json_path):
            raise FileNotFoundError(f"File not found: {self.json_path}")
        
        with open(self.json_path, "r", encoding="utf-8") as f:
            original_dataset = json.load(f)
        
        evaluation_samples = []
        
        # Collect all valid definitions
        all_definitions = []
        for module_name, module_data in original_dataset.items():
            for definition in module_data.get("definitions", []):
                if not isinstance(definition, dict):
                    continue
                
                informal = definition.get("semantic_analysis", {}).get("informal", "")
                def_name = definition.get("name", "")
                
                # Ensure valid non-empty informal description and definition name
                if informal and def_name:
                    all_definitions.append({
                        "module": module_name,
                        "name": def_name,
                        "query": informal
                    })
        
        print(f"   Found {len(all_definitions)} valid definitions with informal descriptions")
        
        # Random sampling
        if len(all_definitions) < sample_size:
            print(f"Warning: Only {len(all_definitions)} definitions available, using all")
            sample_size = len(all_definitions)
        
        random.shuffle(all_definitions)
        self.evaluation_data = all_definitions[:sample_size]
        
        print(f"Prepared {len(self.evaluation_data)} evaluation samples")
        return self.evaluation_data
    
    def evaluate_model(self, model, model_name, semantic_top_k=10, murp_top_k=5, semantic_min_score=0.1, return_top_k=10):
        """
        Evaluate model performance
        """
        if not self.evaluation_data:
            self.prepare_evaluation_data()
        
        sample_size = len(self.evaluation_data)
        
        print(f"Evaluating {model_name} with {sample_size} samples...")
        print("=" * 80)
        
        # Evaluation metrics
        results = {
            "model_name": model_name,
            "total_samples": sample_size,
            "correct_predictions": 0,
            "incorrect_predictions": 0,
            "module_recall": 0.0,
            "query_times": [],
            "module_coverage": defaultdict(int),
            "detailed_results": []
        }
        
        # Progress counter
        progress_interval = max(1, sample_size // 10)  # Show progress every 10%
        
        # Evaluate each sample
        for i, sample in enumerate(self.evaluation_data):
            if i % progress_interval == 0:
                print(f"   Processing sample {i+1}/{sample_size} ({((i+1)/sample_size)*100:.1f}%)")
            
            start_time = time.time()
            
            try:
                # Execute search
                if model_name == "Semantic Search":
                    retrieved_modules = model.semantic_search(
                        sample['query'], 
                        top_k=semantic_top_k,
                        min_score=semantic_min_score,
                        return_top_k=return_top_k
                    )
                else:  # Hybrid Search
                    retrieved_modules = model.hybrid_search(
                        query=sample['query'],
                        semantic_top_k=semantic_top_k,
                        murp_top_k=murp_top_k,
                        semantic_min_score=semantic_min_score,
                        return_top_k=return_top_k
                    )
                
                elapsed_time = time.time() - start_time
                results["query_times"].append(elapsed_time)
                
                # Get retrieved modules
                retrieved_modules_set = set(retrieved_modules)
                target_module = sample["module"]
                
                # Check if target module is in retrieval results
                is_correct = target_module in retrieved_modules_set
                
                # Update results
                if is_correct:
                    results["correct_predictions"] += 1
                    results["module_coverage"][target_module] += 1
                else:
                    results["incorrect_predictions"] += 1
                
                # Store detailed results
                detailed_result = {
                    "sample_id": i,
                    "definition_name": sample["name"],
                    "target_module": target_module,
                    "query": sample["query"],
                    "retrieved_modules": list(retrieved_modules_set),
                    "is_correct": is_correct,
                    "time_taken": elapsed_time
                }
                results["detailed_results"].append(detailed_result)
                
            except Exception as e:
                elapsed_time = time.time() - start_time
                results["query_times"].append(elapsed_time)
                results["incorrect_predictions"] += 1
                print(f"Error during search for sample {i+1}: {e}")
                detailed_result = {
                    "sample_id": i,
                    "definition_name": sample["name"],
                    "target_module": target_module,
                    "query": sample["query"],
                    "error": str(e),
                    "is_correct": False,
                    "time_taken": elapsed_time
                }
                results["detailed_results"].append(detailed_result)
        
        # Calculate recall
        results["module_recall"] = results["correct_predictions"] / sample_size if sample_size > 0 else 0
        
        # Calculate average query time
        results["avg_query_time"] = sum(results["query_times"]) / sample_size if sample_size > 0 else 0
        
        # Calculate module coverage
        total_correct = results["correct_predictions"]
        if total_correct > 0:
            for module, count in results["module_coverage"].items():
                results["module_coverage"][module] = count / total_correct
        
        print("=" * 80)
        print(f"{model_name.upper()} EVALUATION SUMMARY")
        print("=" * 80)
        print(f"Total Samples: {results['total_samples']}")
        print(f"Correct Predictions: {results['correct_predictions']}")
        print(f"Incorrect Predictions: {results['incorrect_predictions']}")
        print(f"Module Recall: {results['module_recall']:.4f}")
        print(f"Average Query Time: {results['avg_query_time']:.2f} seconds")
        
        # Print top modules by coverage
        if results["module_coverage"]:
            print("Top Modules by Coverage:")
            sorted_modules = sorted(results["module_coverage"].items(), key=lambda x: x[1], reverse=True)[:5]
            for module, coverage in sorted_modules:
                print(f"  {module}: {coverage:.2f}")
        
        return results
    
    def save_evaluation_results(self, results, output_path="model_evaluation.json"):
        """Write one evaluation run to ``output_path`` as indented UTF-8 JSON.

        The file has two top-level keys: ``evaluation_summary`` (the scalar
        metrics plus a plain-dict copy of ``module_coverage``) and
        ``detailed_results`` (passed through unchanged). Errors raised while
        opening or writing the file are reported on stdout instead of being
        propagated; a missing key in ``results`` still raises ``KeyError``.
        """
        print(f"Saving evaluation results to {output_path}")

        # Scalar metrics copied verbatim into the summary, in output order.
        summary_keys = (
            "model_name",
            "total_samples",
            "correct_predictions",
            "incorrect_predictions",
            "module_recall",
            "avg_query_time",
        )
        summary = {key: results[key] for key in summary_keys}
        # dict(...) turns any mapping subclass into a plain dict for the dump.
        summary["module_coverage"] = dict(results["module_coverage"])

        payload = {
            "evaluation_summary": summary,
            "detailed_results": results["detailed_results"],
        }

        try:
            with open(output_path, mode="w", encoding="utf-8") as out_file:
                json.dump(payload, out_file, indent=2, ensure_ascii=False)
            print("Evaluation results saved successfully")
        except Exception as exc:
            print(f"Failed to save evaluation results: {exc}")
    
    def compare_models(self, semantic_top_k=10, murp_top_k=5, semantic_min_score=0.1, return_top_k=10):
        """Run the ablation: pure semantic search versus hybrid search.

        Builds a ``SemanticSearch`` and a ``HybridSemanticSearch`` model (both
        using ``sentence-transformers/sentence-t5-large``), evaluates each one
        through ``self.evaluate_model``, saves both result sets to
        ``semantic_search_evaluation.json`` / ``hybrid_search_evaluation.json``
        and prints a side-by-side comparison table.

        Args:
            semantic_top_k: Forwarded to ``evaluate_model`` for both models.
            murp_top_k: Forwarded to ``evaluate_model`` for the hybrid model only.
            semantic_min_score: Forwarded to ``evaluate_model`` for both models.
            return_top_k: Forwarded to ``evaluate_model`` for both models.

        Returns:
            The tuple ``(semantic_results, hybrid_results)``.
        """

        def relative_change(delta, baseline):
            # A zero baseline (e.g. the semantic model got no sample right)
            # has no meaningful percentage. Report "n/a" instead of raising
            # ZeroDivisionError after both evaluations have already finished.
            return f"{delta / baseline:.2%}" if baseline else "n/a"

        # Initialize models
        semantic_model = SemanticSearch(
            self.json_path,
            model_name="sentence-transformers/sentence-t5-large",
            verbose=False
        )

        hybrid_model = HybridSemanticSearch(
            self.json_path,
            self.murp_embedding_path,
            semantic_model="sentence-transformers/sentence-t5-large",
            murp_model_type="poincare",
            verbose=False
        )

        # Evaluate pure semantic search (no MuRP parameter for this baseline)
        semantic_results = self.evaluate_model(
            semantic_model,
            "Semantic Search",
            semantic_top_k=semantic_top_k,
            semantic_min_score=semantic_min_score,
            return_top_k=return_top_k
        )

        # Evaluate hybrid search
        hybrid_results = self.evaluate_model(
            hybrid_model,
            "Hybrid Search",
            semantic_top_k=semantic_top_k,
            murp_top_k=murp_top_k,
            semantic_min_score=semantic_min_score,
            return_top_k=return_top_k
        )

        # Save results
        self.save_evaluation_results(semantic_results, "semantic_search_evaluation.json")
        self.save_evaluation_results(hybrid_results, "hybrid_search_evaluation.json")

        semantic_recall = semantic_results["module_recall"]
        semantic_time = semantic_results["avg_query_time"]
        recall_diff = hybrid_results["module_recall"] - semantic_recall
        time_diff = hybrid_results["avg_query_time"] - semantic_time

        # Print comparison results
        print("=" * 80)
        print("ABLATION STUDY RESULTS")
        print("=" * 80)
        print(f"{'Model':<20} | {'Recall':<10} | {'Avg Time (s)':<12} | {'Improvement':<12}")
        print("-" * 60)

        print(f"{'Semantic Search':<20} | {semantic_recall:.4f}    | {semantic_time:.4f}      | {'-':<12}")
        # The "+" format flag emits the correct sign, so a regression shows as
        # "-0.0123" rather than "+-0.0123".
        print(f"{'Hybrid Search':<20} | {hybrid_results['module_recall']:.4f}    | {hybrid_results['avg_query_time']:.4f}      | {recall_diff:+.4f}")
        print("-" * 60)
        print(f"Recall Improvement: {recall_diff:.4f} ({relative_change(recall_diff, semantic_recall)})")
        print(f"Time Cost Increase: {time_diff:.4f} seconds ({relative_change(time_diff, semantic_time)})")

        return semantic_results, hybrid_results


# ----------------------
# Usage Example
# ----------------------
def main():
    """Run the ablation study, then one verbose hybrid-search example query.

    Any exception raised along the way is printed together with its
    traceback; nothing is re-raised.
    """
    # Configure paths
    json_path = "./Informalisation_and_Mathematical_DSRL/merged_with_embeddings_and_triples.json"
    murp_embedding_path = "./Mathlib4_embeddings/outputs_cleaned/embeddings/model_dict_poincare_300_my_dataset_cleaned"

    # The same retrieval settings drive both the evaluation and the demo query.
    search_params = {
        "semantic_top_k": 10,
        "murp_top_k": 5,
        "semantic_min_score": 0.1,
        "return_top_k": 10,  # Return 10 modules
    }
    separator = "=" * 80

    try:
        # Ablation study: semantic-only versus hybrid retrieval
        print("Initializing Model Evaluator...")
        evaluator = ModelEvaluator(json_path, murp_embedding_path)

        print("Starting ablation study...")
        _semantic_results, _hybrid_results = evaluator.compare_models(**search_params)

        # Single query example with a verbose hybrid model
        print(separator)
        print("Running single query example with Hybrid Search...")
        hybrid_model = HybridSemanticSearch(
            json_path,
            murp_embedding_path,
            semantic_model="sentence-transformers/sentence-t5-large",
            murp_model_type="poincare",
            verbose=True,
        )

        query = "For any real matrix A: Matrix m × n, if the columns of A are pairwise orthogonal, then the matrix Aᵀ * A is a diagonal matrix."
        retrieved_modules = hybrid_model.hybrid_search(query=query, **search_params)

        print(f"Retrieved {len(retrieved_modules)} modules:")
        for rank, module_name in enumerate(retrieved_modules, start=1):
            print(f" {rank}. {module_name}")

    except Exception as exc:
        print(f"Error: {exc}")
        import traceback
        traceback.print_exc()


# Run the ablation study only when this file is executed directly, so that
# importing it does not trigger the evaluation.
if __name__ == "__main__":
    main()