# Autoformalisation Based on Large Language Models via Retrieval-Augmented Generation

This notebook demonstrates autoformalisation with large language models (LLMs) combined with retrieval-augmented generation (RAG) over Mathlib4. The objective is to explore how unstructured natural-language mathematical statements can be transformed into structured, formal Lean 4 representations.


In [1]:
import json
import numpy as np
from sentence_transformers import SentenceTransformer
import openai
import os
import torch
import gensim.downloader as api

## Preparation: Initialize OpenAI API

In [None]:
# Read the API key from the environment rather than pasting it into the
# notebook, so the secret is never saved or committed together with the code.
# Export OPENAI_API_KEY in your shell before starting Jupyter; if it is unset
# the key stays empty, exactly as the former placeholder did.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
openai.api_key = OPENAI_API_KEY

## 1.  Text Embedding Model Retrieval
### 1.1.  Model Selection

In [3]:
# Key (into MODELS below) of the embedding model the loaders will use.
SELECTED_MODEL_KEY = "sentence-t5-large"  # Available model

# Supported embedding back-ends: short key -> model identifier.
# "Glove" and "Word2Vec" are gensim-data names (loaded with gensim.downloader);
# the remaining entries are Hugging Face model ids loaded by name.
MODELS = dict(
    [
        ("Glove", "glove-wiki-gigaword-100"),
        ("Word2Vec", "word2vec-google-news-300"),
        ("bert-base", "bert-base-uncased"),
        ("bert-large", "bert-large-uncased"),
        ("defsent-bert", "princeton-nlp/defsent-bert"),
        ("defsent-roberta", "princeton-nlp/defsent-roberta"),
        ("distilroberta-vl", "sentence-transformers/distilroberta-base-v1"),
        ("mpnet-base-v2", "sentence-transformers/all-mpnet-base-v2"),
        ("sentence-t5-large", "sentence-transformers/sentence-t5-large"),
        # NOTE(review): this id looks like a placeholder -- replace it with a
        # real checkpoint before selecting "HIT".
        ("HIT", "HIT/some-hit-model"),
    ]
)


### 1.2 Load Embedding Models and Preparations

In [5]:
# -------------------------------
# Load Embedding Models
# -------------------------------
def load_embedding_model(model_key):
    """Instantiate the embedding model registered under ``model_key`` in ``MODELS``.

    Every returned object exposes ``encode(texts)``, which accepts a single
    string or a list of strings:

    * ``"Word2Vec"`` / ``"Glove"``: gensim word vectors. A text is embedded as
      the mean of the vectors of its in-vocabulary words, or a zero vector if
      none of its words is known.
    * keys starting with ``"HIT"``: a Hugging Face encoder whose first-token
      hidden state (``last_hidden_state[:, 0, :]``) is the text embedding.
    * BERT / DefSent keys: a Hugging Face encoder with attention-masked mean
      pooling over the last hidden layer.
    * anything else: a ``SentenceTransformer`` loaded by name.

    Args:
        model_key: A key of ``MODELS``.

    Returns:
        An encoder object. The wrapper classes return a NumPy array with one
        row per input text; ``SentenceTransformer`` returns whatever its own
        ``encode`` returns.

    Raises:
        ValueError: If ``model_key`` is not a key of ``MODELS``.
    """
    if model_key not in MODELS:
        raise ValueError(f"Unknown model key: {model_key}")

    model_name = MODELS[model_key]

    if model_key in ["Word2Vec", "Glove"]:
        import re

        # Word-character tokeniser. Unlike str.split(), it does not keep
        # punctuation glued to words ("orthogonal,", "A:"). Such tokens would
        # otherwise miss the vocabulary and silently drop out of the average.
        word_pattern = re.compile(r"\w+")

        class WordVectorWrapper:
            """Mean-of-word-vectors sentence encoder over a gensim model."""

            def __init__(self):
                print(f"Loading word vector model: {model_name}")
                self.wv = api.load(model_name)
                self.dim = self.wv.vector_size

            def encode(self, texts):
                """Embed a string or a list of strings, one row per text."""
                if isinstance(texts, str):
                    texts = [texts]
                embeddings = []
                for text in texts:
                    words = word_pattern.findall(text.lower())
                    word_vectors = [self.wv[word] for word in words if word in self.wv]
                    if word_vectors:
                        embeddings.append(np.mean(word_vectors, axis=0))
                    else:
                        # No known word: use a zero vector of the model's width.
                        embeddings.append(np.zeros(self.dim))
                return np.array(embeddings)

        return WordVectorWrapper()

    elif model_key.startswith("HIT"):
        from transformers import AutoTokenizer, AutoModel

        class HierarchyTransformerWrapper:
            """Hugging Face encoder using the first-token hidden state as embedding."""

            def __init__(self):
                print(f"Loading HIT model: {model_name}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModel.from_pretrained(model_name)
                self.model.eval()

            def encode(self, texts):
                """Embed a string or a list of strings in one padded batch."""
                if isinstance(texts, str):
                    texts = [texts]
                inputs = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
                with torch.no_grad():
                    outputs = self.model(**inputs)
                return outputs.last_hidden_state[:, 0, :].numpy()

        return HierarchyTransformerWrapper()

    elif model_key in ["bert-base", "bert-large", "defsent-bert", "defsent-roberta"]:
        from transformers import AutoTokenizer, AutoModel

        class BertWrapper:
            """Hugging Face encoder with attention-masked mean pooling."""

            def __init__(self):
                print(f"Loading BERT model: {model_name}")
                self.tokenizer = AutoTokenizer.from_pretrained(model_name)
                self.model = AutoModel.from_pretrained(model_name)
                self.model.eval()

            def encode(self, texts):
                """Embed a string or a list of strings, one text at a time."""
                if isinstance(texts, str):
                    texts = [texts]
                embeddings = []
                for text in texts:
                    inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True)
                    with torch.no_grad():
                        outputs = self.model(**inputs)
                    last_hidden = outputs.last_hidden_state
                    # Zero out padding positions, then average over real tokens only.
                    attention_mask = inputs['attention_mask'].unsqueeze(-1)
                    masked_hidden = last_hidden * attention_mask
                    pooled = masked_hidden.sum(dim=1) / attention_mask.sum(dim=1)
                    embeddings.append(pooled.squeeze().numpy())
                return np.array(embeddings)

        return BertWrapper()

    else:
        print(f"Loading Sentence Transformer model: {model_name}")
        return SentenceTransformer(model_name)

# -------------------------------
# Initialize Embedding Model
# -------------------------------
# Load the selected embedding model. If that fails and a non-default model was
# selected, fall back to the default sentence-t5-large model. If the default
# model itself failed, reloading it would just fail again, so re-raise the
# original error instead.
try:
    embedding_model = load_embedding_model(SELECTED_MODEL_KEY)
except Exception as e:
    print(f"Model loading failed: {e}")
    if SELECTED_MODEL_KEY == "sentence-t5-large":
        raise
    SELECTED_MODEL_KEY = "sentence-t5-large"
    print(f"Fallback to default model: {SELECTED_MODEL_KEY}")
    embedding_model = load_embedding_model(SELECTED_MODEL_KEY)

# -------------------------------
# Load JSON File and Extract Embeddings
# -------------------------------
json_path = "./Informalisation_and_Mathematical_DSRL/merged_with_embeddings_and_triples.json"
if not os.path.exists(json_path):
    raise FileNotFoundError(f"Embedding file not found: {json_path}")

with open(json_path, "r", encoding="utf-8") as f:
    concept_data_raw = json.load(f)

flattened_concepts = []
embeddings = []

# Walk module -> definitions -> semantic_analysis.concepts and keep every
# concept that carries a precomputed embedding. flattened_concepts[i]
# describes row i of `embeddings`.
for module_content in concept_data_raw.values():
    for definition in module_content.get("definitions", []):
        for concept in definition.get("semantic_analysis", {}).get("concepts", []):
            vec = concept.get("embedding_vector")
            if vec is not None:
                embeddings.append(vec)
                flattened_concepts.append(concept)

if not embeddings:
    # An empty (0,)-shaped array would only fail later, with an opaque axis
    # error inside search_similar_concepts; fail early with a clear message.
    raise ValueError(f"No valid embeddings found in {json_path}")

embeddings = np.array(embeddings)
print(f"Extracted {len(embeddings)} concept embedding vectors")

# -------------------------------
# Search for Similar Concepts
# -------------------------------
def search_similar_concepts(query, top_k=5):
    """Return the ``top_k`` concepts whose embeddings are most cosine-similar to ``query``.

    Uses the module-level ``embedding_model``, ``embeddings`` and
    ``flattened_concepts`` built above.

    Args:
        query: Natural-language query string.
        top_k: Number of concepts to return. Values <= 0 return an empty list.

    Returns:
        List of concept dicts from ``flattened_concepts``, most similar first.
    """
    if top_k <= 0:
        # np.argsort(...)[-0:] would select *every* row, not none.
        return []
    query_vec = np.asarray(embedding_model.encode(query))
    if query_vec.ndim > 1:
        query_vec = query_vec[0]
    # The epsilon keeps an all-zero vector from producing NaN scores. One
    # source is a query with no known words under the word-vector wrappers.
    query_vec = query_vec / (np.linalg.norm(query_vec) + 1e-8)
    norm_embeddings = embeddings / (np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-8)
    cosine_scores = norm_embeddings @ query_vec
    top_indices = np.argsort(cosine_scores)[-top_k:][::-1]
    return [flattened_concepts[i] for i in top_indices]


### 1.3  Retrieval of Relevant Mathlib4 Module Imports based on Natural Language Queries

In [6]:
import os
import json
import numpy as np
from sentence_transformers import SentenceTransformer

# Disable tokenizer parallelism to avoid its parallelism warnings.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

class SemanticSearch:
    """Cosine-similarity search over concept embeddings stored in a JSON file.

    The file is read as ``{module_name: {"definitions": [{"semantic_analysis":
    {"concepts": [...]}}, ...]}}``. Each concept that has an
    ``"embedding_vector"`` becomes one searchable row. Queries are encoded
    with a ``SentenceTransformer``.

    Scores are only meaningful if that model produced the stored vectors.
    Presumably this is the default ``sentence-t5-large``; TODO confirm
    against the embedding pipeline.
    """

    def __init__(self, json_path, model_name="sentence-transformers/sentence-t5-large"):
        """Load the concept embeddings from ``json_path`` and the query encoder.

        Args:
            json_path: Path to the concept/embedding JSON file.
            model_name: ``SentenceTransformer`` model used to encode queries.

        Raises:
            FileNotFoundError: If ``json_path`` does not exist.
            ValueError: If the file contains no concept embeddings.
        """
        self.json_path = json_path
        self.model_name = model_name
        self.embedding_model = None
        self.embeddings = None          # (n_concepts, dim) array, set by load_data
        self.flattened_concepts = []    # flattened_concepts[i] describes embeddings[i]

        self.load_data()
        self.load_model()

    def load_data(self):
        """Load the JSON file and extract one embedding row per concept.

        Concepts missing ``name`` get it from ``concept_name`` (or
        ``"Unknown"``). Concepts missing ``module_path`` get the enclosing
        module key. The source dicts are copied, not modified.
        """
        print("Loading data...")

        if not os.path.exists(self.json_path):
            raise FileNotFoundError(f"File not found: {self.json_path}")

        with open(self.json_path, "r", encoding="utf-8") as f:
            concept_data_raw = json.load(f)

        embeddings_list = []

        for module_name, module_content in concept_data_raw.items():
            for definition in module_content.get("definitions", []):
                semantic_analysis = definition.get("semantic_analysis", {})
                for concept in semantic_analysis.get("concepts", []):
                    vec = concept.get("embedding_vector")
                    if vec is None:
                        continue
                    embeddings_list.append(vec)
                    # Ensure the concept has the fields print_results relies on.
                    concept_copy = concept.copy()
                    if "name" not in concept_copy:
                        concept_copy["name"] = concept_copy.get("concept_name", "Unknown")
                    if "module_path" not in concept_copy:
                        concept_copy["module_path"] = module_name
                    self.flattened_concepts.append(concept_copy)

        if embeddings_list:
            self.embeddings = np.array(embeddings_list)
            print(f"Loaded {len(self.embeddings)} concept embeddings")
            print(f"Embedding dimension: {self.embeddings.shape[1]}")
        else:
            raise ValueError("No valid embeddings found in the JSON file")

    def load_model(self):
        """Load the sentence transformer model used to encode queries.

        On failure this falls back to ``all-mpnet-base-v2``.

        NOTE(review): unless the stored vectors were also produced by that
        model, query and concept vectors then come from different embedding
        spaces, and the scores are unlikely to be meaningful.
        """
        print(f"Loading model: {self.model_name}")
        try:
            self.embedding_model = SentenceTransformer(self.model_name)
            print("Model loaded successfully")
        except Exception as e:
            print(f"Failed to load model: {e}")
            print("Falling back to default model...")
            self.embedding_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

    def search_similar_concepts(self, query, top_k=5, min_score=0.0):
        """
        Search for similar concepts using cosine similarity

        Args:
            query (str): Search query
            top_k (int): Number of top results to consider. Values <= 0
                return an empty list.
            min_score (float): Minimum similarity score threshold, applied
                after the top_k cut

        Returns:
            list: dicts with keys 'concept', 'similarity_score' and 'index',
            most similar first
        """
        if self.embeddings is None or len(self.embeddings) == 0:
            print("No embeddings available")
            return []
        if top_k <= 0:
            # np.argsort(...)[-0:] would select every row rather than none.
            return []

        print(f"Searching for: '{query}'")

        try:
            # Generate the query embedding; the encoder may return a list or a batch.
            query_vec = np.asarray(self.embedding_model.encode([query]))
            if query_vec.ndim > 1:
                query_vec = query_vec[0]

            print(f"Query embedding shape: {query_vec.shape}")

        except Exception as e:
            print(f"Failed to encode query: {e}")
            return []

        try:
            # Normalise both sides; the epsilon guards against zero-norm vectors.
            query_vec = query_vec / (np.linalg.norm(query_vec) + 1e-8)
            norm_embeddings = self.embeddings / (np.linalg.norm(self.embeddings, axis=1, keepdims=True) + 1e-8)

            cosine_scores = np.dot(norm_embeddings, query_vec)

            # Indices of the top_k highest scores, best first.
            top_indices = np.argsort(cosine_scores)[-top_k:][::-1]

            print(f"Top {min(top_k, len(top_indices))} similarity scores: {cosine_scores[top_indices]}")

        except Exception as e:
            print(f"Failed to calculate similarity: {e}")
            return []

        results = []
        for i in top_indices:
            if i >= len(self.flattened_concepts):
                continue
            score = cosine_scores[i]
            if score >= min_score:
                results.append({
                    'concept': self.flattened_concepts[i],
                    'similarity_score': float(score),
                    'index': int(i)
                })

        print(f"Found {len(results)} results above threshold {min_score}")
        return results

    def print_results(self, results, show_details=False):
        """Pretty-print search results, one entry per distinct (name, module).

        Args:
            results: List returned by ``search_similar_concepts``.
            show_details: Also print the remaining non-empty concept fields
                (except the embedding vector), truncated to 100 characters.
        """
        if not results:
            print("No results found")
            return

        # Deduplicate by (name, module), keeping the first occurrence. For
        # search_similar_concepts output, that is the highest-scoring one.
        seen = set()
        unique_results = []
        for result in results:
            concept = result['concept']
            name = concept.get('name', 'Unknown')
            module = concept.get('module_path', 'Unknown')
            if (name, module) not in seen:
                seen.add((name, module))
                unique_results.append((name, module, result['similarity_score'], concept))

        # Report the number of entries actually listed below.
        print(f"\nSearch Results ({len(unique_results)} found):")
        print("=" * 60)

        for i, (name, module, score, concept) in enumerate(unique_results, 1):
            print(f"{i}. Name: {name}")
            print(f"   Module: {module}")
            print(f"   Similarity: {score:.4f}")

            if show_details:
                for field, value in concept.items():
                    if field in ('name', 'module_path', 'embedding_vector') or not value:
                        continue
                    text = str(value)
                    # Only mark truncation when something was actually cut off.
                    suffix = "..." if len(text) > 100 else ""
                    print(f"   {field}: {text[:100]}{suffix}")
            print()


# -------------------------------
# Usage Example
# -------------------------------
def main():
    """Run the example query through SemanticSearch and print the matches."""
    data_file = "./Informalisation_and_Mathematical_DSRL/merged_with_embeddings_and_triples.json"

    try:
        searcher = SemanticSearch(data_file)

        # Other handy example queries: "orthogonal matrix", "diagonal matrix",
        # "matrix transpose".
        example_queries = [
            "For any real matrix A: Matrix m × n, if the columns of A are pairwise orthogonal, then the matrix Aᵀ * A is a diagonal matrix.",
        ]

        for q in example_queries:
            print("\n" + "=" * 80)
            # Keep only hits with cosine similarity >= 0.1.
            hits = searcher.search_similar_concepts(q, top_k=10, min_score=0.1)
            searcher.print_results(hits, show_details=False)

    except Exception as exc:
        print(f"Error in main: {exc}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()


## 2. Multi-Relational Hyperbolic (Poincaré) Embeddings

This section introduces multi-relational hyperbolic embeddings in the Poincaré ball model (MuRP), which enable structured representation and semantic querying of mathematical knowledge. By combining textual definitions with hierarchical relationships, these embeddings exploit the fact that hyperbolic space naturally captures tree-like structures and can accommodate a large number of concepts and relations. Its geometric properties, namely negative curvature and exponentially expanding volume, make it well suited to organising complex mathematical knowledge, supporting enhanced reasoning and autoformalisation in LLM-based systems.

### 2.1 Hybrid Module Retrieval

In [7]:
import os
import json
import numpy as np
import re
from sentence_transformers import SentenceTransformer
from web.embeddings import load_embedding
from web.evaluate import poincare_distance
from scipy.spatial.distance import euclidean
from collections import defaultdict

# Disable tokenizer parallelism to avoid its parallelism warnings.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

# Shape of a well-formed cleaned concept name: lowercase ASCII letters, digits
# and underscores only, at least 2 characters long.
valid_definiendum_pattern = re.compile(r"^[a-z0-9_]{2,}$")
# Mathematical operator / bracket characters not allowed in cleaned concept names.
math_symbols = {*"+-*/=∑∏√∫<>∈∉{}[]()"}

def clean_definiendum(definiendum):
    """Normalise a concept/definition name into a lowercase identifier-like key.

    Dots and spaces become underscores. Every other character outside
    ``[A-Za-z0-9_]`` is dropped, which includes mathematical operator symbols
    and all non-ASCII characters. The result is then lowercased.

    Args:
        definiendum: The raw name. Non-strings (including ``None``) yield ``""``.

    Returns:
        The cleaned name, or ``""`` if the input is not a non-empty string or
        nothing survives cleaning. The result is not checked against
        ``valid_definiendum_pattern``, so it may be a single character.

    >>> clean_definiendum("Matrix.transpose")
    'matrix_transpose'
    """
    if not definiendum or not isinstance(definiendum, str):
        return ""
    # Map dots first so dotted names such as "Matrix.transpose" keep their
    # parts separated instead of being glued together.
    definiendum = definiendum.replace(".", "_").replace(" ", "_")
    # This single pass keeps only ASCII letters, digits and '_'. It therefore
    # also strips math operators, brackets and every non-ASCII character.
    return re.sub(r"[^A-Za-z0-9_]", "", definiendum).lower()

def normalize_entity_name(name):
    """Lowercase ``name`` and join its whitespace-separated words with underscores.

    Leading/trailing whitespace is ignored and inner whitespace runs collapse
    to a single underscore, e.g. ``"  Inner  Product "`` -> ``"inner_product"``.
    """
    tokens = name.lower().split()
    return "_".join(tokens)

class HybridSemanticSearch:
    """Two-stage concept retrieval: sentence-embedding search, then MuRP neighbours.

    Stage 1 ranks the concepts stored in ``json_path`` by cosine similarity to
    the query. Stage 2 looks up the cleaned names of those concepts in a MuRP
    embedding and collects their nearest neighbours, using Poincaré or
    Euclidean distance. ``hybrid_search`` combines both stages into a ranked
    list of modules.
    """

    def __init__(self, json_path, murp_embedding_path,
                 semantic_model="sentence-transformers/sentence-t5-large",
                 murp_model_type="poincare"):
        """
        Hybrid retrieval system: semantic search first, then MuRP for secondary retrieval

        Args:
            json_path: JSON file path (contains concepts and embeddings)
            murp_embedding_path: MuRP embedding file path
            semantic_model: Sentence Transformer model name
            murp_model_type: MuRP model type ("poincare" or "euclidean")
        """
        self.json_path = json_path
        self.murp_embedding_path = murp_embedding_path
        self.semantic_model_name = semantic_model
        self.murp_model_type = murp_model_type

        # Initialize components
        self.semantic_model = None
        self.murp_embeddings = None
        self.semantic_embeddings = None
        self.flattened_concepts = []
        self.murp_vocab = []
        self.original_dataset = None
        # Lazily built cache: cleaned definition name -> first module defining it.
        self._def_module_cache = None

        # Load data
        self.load_semantic_data()
        self.load_semantic_model()
        self.load_murp_embeddings()

    def load_semantic_data(self):
        """Load the JSON dataset and extract one semantic embedding row per concept."""
        print("Loading semantic search data...")

        if not os.path.exists(self.json_path):
            raise FileNotFoundError(f"File not found: {self.json_path}")

        with open(self.json_path, "r", encoding="utf-8") as f:
            self.original_dataset = json.load(f)
        # The definition->module index depends on the dataset; rebuild on demand.
        self._def_module_cache = None

        embeddings_list = []

        for module_name, module_content in self.original_dataset.items():
            for definition in module_content.get("definitions", []):
                semantic_analysis = definition.get("semantic_analysis", {})
                for concept in semantic_analysis.get("concepts", []):
                    vec = concept.get("embedding_vector")
                    if vec is None:
                        continue
                    embeddings_list.append(vec)
                    # Copy so the fallback fields do not modify the loaded dataset.
                    concept_copy = concept.copy()
                    if "name" not in concept_copy:
                        concept_copy["name"] = concept_copy.get("concept_name", "Unknown")
                    if "module_path" not in concept_copy:
                        concept_copy["module_path"] = module_name
                    self.flattened_concepts.append(concept_copy)

        if embeddings_list:
            self.semantic_embeddings = np.array(embeddings_list)
            print(f"Loaded {len(self.semantic_embeddings)} semantic concept embeddings")
        else:
            raise ValueError("No valid semantic embeddings found")

    def load_semantic_model(self):
        """Load the Sentence Transformer model used to encode queries (re-raises on failure)."""
        print(f"Loading semantic model: {self.semantic_model_name}")
        try:
            self.semantic_model = SentenceTransformer(self.semantic_model_name)
            print("Semantic model loaded successfully")
        except Exception as e:
            print(f"Failed to load semantic model: {e}")
            raise

    def load_murp_embeddings(self):
        """Load the MuRP embeddings (normalised, lowercased vocabulary) via ``load_embedding``."""
        print(f"Loading MuRP embeddings from: {self.murp_embedding_path}")
        print(f"Model type: {self.murp_model_type}")

        try:
            self.murp_embeddings = load_embedding(
                self.murp_embedding_path,
                format="dict",
                normalize=True,
                lower=True,
                clean_words=False
            )

            self.murp_vocab = self.murp_embeddings.words
            print(f"Loaded MuRP embeddings with {len(self.murp_vocab)} vocabulary items")

        except Exception as e:
            print(f"Failed to load MuRP embeddings: {e}")
            raise

    def semantic_search(self, query, top_k=20, min_score=0.0):
        """Phase 1: cosine-similarity search, deduplicated by cleaned concept name.

        Returns up to ``top_k`` dicts with keys ``'concept'``,
        ``'semantic_score'`` and ``'index'``, best first. Concepts whose name
        cleans to ``""`` are skipped.
        """
        print(f"Phase 1: Semantic search for: '{query}'")

        if self.semantic_embeddings is None or len(self.semantic_embeddings) == 0:
            print("No semantic embeddings available")
            return []

        try:
            # Generate query embedding
            query_vec = self.semantic_model.encode([query])
            if isinstance(query_vec, list):
                query_vec = np.array(query_vec)
            if query_vec.ndim > 1:
                query_vec = query_vec[0]

            # Normalise vectors; the epsilon guards against zero-norm rows.
            query_vec = query_vec / (np.linalg.norm(query_vec) + 1e-8)
            norm_embeddings = self.semantic_embeddings / (np.linalg.norm(self.semantic_embeddings, axis=1, keepdims=True) + 1e-8)

            cosine_scores = np.dot(norm_embeddings, query_vec)

            # All indices, sorted by similarity (descending)
            all_indices = np.argsort(cosine_scores)[::-1]

            print(f"Found {len(all_indices)} potential semantic results")
            print(f"Similarity range: {cosine_scores[all_indices[-1]]:.4f} to {cosine_scores[all_indices[0]]:.4f}")

        except Exception as e:
            print(f"Semantic search failed: {e}")
            return []

        # Collect results (with deduplication)
        results = []
        seen_concepts = set()

        for i in all_indices:
            if len(results) >= top_k:
                break

            if i < len(self.flattened_concepts):
                concept = self.flattened_concepts[i]
                score = cosine_scores[i]

                # Skip scores below threshold
                if score < min_score:
                    continue

                cleaned_name = clean_definiendum(concept.get('name', ''))

                if cleaned_name and cleaned_name not in seen_concepts:
                    # Keep results as dicts with a 'concept' key; later cells rely on it.
                    results.append({
                        'concept': concept,
                        'semantic_score': float(score),
                        'index': int(i)
                    })
                    seen_concepts.add(cleaned_name)

        print(f"{len(results)} unique results above threshold {min_score}")
        return results

    def murp_search(self, concept_names, top_k=10):
        """Phase 2: look up each concept name in the MuRP vocabulary and fetch its neighbours.

        Several spelling variants (dots vs. underscores, last dotted segment)
        are tried per name.

        Returns:
            dict mapping each found input name to
            ``{'query_name_used': <vocab entry>, 'neighbors': [...]}``.
        """
        print(f"Phase 2: MuRP search for {len(concept_names)} concepts")

        # Explicit check: truthiness of the embedding object itself is not a
        # reliable emptiness test.
        if self.murp_embeddings is None or len(self.murp_vocab) == 0:
            print("No MuRP embeddings available")
            return {}

        murp_results = {}
        found_concepts = []
        missing_concepts = []

        for concept_name in concept_names:
            possible_names = [
                concept_name,
                concept_name.replace('.', '_'),
                concept_name.replace('_', '.'),
                concept_name.split('.')[-1] if '.' in concept_name else concept_name
            ]

            query_name = next((name for name in possible_names if name in self.murp_embeddings), None)

            if query_name is None:
                missing_concepts.append(concept_name)
                continue

            found_concepts.append((concept_name, query_name))

            try:
                neighbors = self.retrieve_murp_neighbors(query_name, top_k)
                murp_results[concept_name] = {
                    'query_name_used': query_name,
                    'neighbors': neighbors
                }
            except Exception as e:
                print(f"MuRP search failed for {concept_name}: {e}")

        print(f"Found MuRP results for {len(found_concepts)} concepts")
        if missing_concepts:
            print(f"Missing in MuRP vocab: {len(missing_concepts)} concepts")

        return murp_results

    def retrieve_murp_neighbors(self, query_word, top_k=10):
        """Return the ``top_k`` vocabulary entries closest to ``query_word``.

        Distance is Poincaré distance for ``murp_model_type == "poincare"`` and
        Euclidean otherwise. The query word itself is excluded.

        Returns:
            List of ``(word, distance, normalised_distance)``, nearest first.
            The normalised distance is min-max scaled to [0, 1] over the
            returned neighbours, and is all zeros when they are equidistant.

        Raises:
            ValueError: If ``query_word`` is not in the MuRP vocabulary.
        """
        if query_word not in self.murp_embeddings:
            raise ValueError(f"Query word '{query_word}' not found in MuRP vocabulary")

        query_vec = self.murp_embeddings[query_word]
        distances = []

        for word in self.murp_vocab:
            if word == query_word:
                continue
            vec = self.murp_embeddings[word]
            if self.murp_model_type == "poincare":
                dist = poincare_distance(query_vec, vec)
            else:
                dist = euclidean(query_vec, vec)
            distances.append((word, dist))

        top_neighbors = sorted(distances, key=lambda x: x[1])[:top_k]

        if not top_neighbors:
            return []

        dists = np.array([dist for _, dist in top_neighbors])
        min_dist, max_dist = dists.min(), dists.max()
        if max_dist > min_dist:
            norm_dists = (dists - min_dist) / (max_dist - min_dist)
        else:
            norm_dists = np.zeros_like(dists)

        return [(word, dist, norm) for (word, dist), norm in zip(top_neighbors, norm_dists)]

    def _module_for_definition(self, name):
        """Return the first module (in dataset order) with a definition whose cleaned name matches ``name``.

        Both sides are compared after ``clean_definiendum``, so dotted MuRP
        vocabulary entries still match. The lookup table is built once and
        cached, which avoids rescanning the whole dataset for every neighbour.
        """
        cache = getattr(self, "_def_module_cache", None)
        if cache is None:
            cache = {}
            for module_name, module_data in (self.original_dataset or {}).items():
                for definition in module_data.get("definitions", []):
                    cleaned = clean_definiendum(definition.get("name", ""))
                    if cleaned:
                        # setdefault keeps the first module, like the original linear scan.
                        cache.setdefault(cleaned, module_name)
            self._def_module_cache = cache
        return cache.get(clean_definiendum(name))

    def hybrid_search(self, query, semantic_top_k=20, murp_top_k=10, semantic_min_score=0.1, return_top_k=10):
        """Hybrid retrieval: semantic search + MuRP secondary retrieval.

        Module relevance is the sum of the semantic scores of its matched
        concepts plus ``1 - normalised_distance`` for every MuRP neighbour
        defined in it.

        Returns:
            dict with ``"semantic_results"`` (phase-1 list), ``"murp_results"``
            (phase-2 dict) and ``"import_modules"`` (up to ``return_top_k``
            module names, highest score first).
        """
        print(f"Starting hybrid search for: '{query}'")

        # Phase 1: semantic search (with deduplication)
        semantic_results = self.semantic_search(query, semantic_top_k, semantic_min_score)

        if not semantic_results:
            print("No semantic results found")
            return {
                "semantic_results": [],
                "murp_results": {},
                "import_modules": []
            }

        # Extract and clean concept names for MuRP retrieval
        raw_concept_names = [result['concept'].get('name', '') for result in semantic_results]
        raw_concept_names = [name for name in raw_concept_names if name and name != 'Unknown']
        cleaned_concept_names = [c for c in map(clean_definiendum, raw_concept_names) if c]

        if not cleaned_concept_names:
            print("All concept names were invalid after cleaning")
            return {
                "semantic_results": semantic_results,
                "murp_results": {},
                "import_modules": []
            }

        print(f"Using {len(cleaned_concept_names)} cleaned concepts for MuRP search")

        # Phase 2: MuRP retrieval (using cleaned lowercase names)
        murp_results = self.murp_search(cleaned_concept_names, murp_top_k)

        module_scores = defaultdict(float)

        # 1. Semantic evidence: accumulate similarity per module.
        for result in semantic_results:
            module = result['concept'].get('module_path', '')
            if module:
                module_scores[module] += result.get('semantic_score', 0)

        # 2. MuRP evidence: closer neighbours (smaller normalised distance) score higher.
        for murp_data in murp_results.values():
            for neighbor, _, norm_dist in murp_data['neighbors']:
                neighbor_module = self._module_for_definition(neighbor)
                if neighbor_module:
                    module_scores[neighbor_module] += (1 - norm_dist)

        # 3. Rank modules and keep the best return_top_k.
        sorted_modules = sorted(module_scores.items(), key=lambda x: x[1], reverse=True)
        top_modules = [module for module, _ in sorted_modules[:return_top_k]]

        # Keep the same return format that later cells rely on.
        return {
            "semantic_results": semantic_results,
            "murp_results": murp_results,
            "import_modules": top_modules
        }

# Usage Example
def main():
    """Build a HybridSemanticSearch, run one example query and print the retrieved modules."""
    data_file = "./Informalisation_and_Mathematical_DSRL/merged_with_embeddings_and_triples.json"
    murp_file = "./Mathlib4_embeddings/outputs_cleaned/embeddings/model_dict_poincare_300_my_dataset_cleaned"

    try:
        print("Initializing Hybrid Search System...")
        engine = HybridSemanticSearch(
            json_path=data_file,
            murp_embedding_path=murp_file,
            semantic_model="sentence-transformers/sentence-t5-large",
            murp_model_type="poincare",
        )

        statement = "For any real matrix A: Matrix m × n, if the columns of A are pairwise orthogonal, then the matrix Aᵀ * A is a diagonal matrix."
        outcome = engine.hybrid_search(
            query=statement,
            semantic_top_k=5,
            murp_top_k=5,
            semantic_min_score=0.1,
            return_top_k=10,
        )

        print("Hybrid Search Results:")
        modules = outcome['import_modules']
        print(f"Retrieved {len(modules)} modules:")
        for rank, module in enumerate(modules, start=1):
            print(f" {rank}. {module}")

        # Sanity-check the result structure that subsequent cells rely on.
        print("\nTesting format compatibility:")
        semantic_hits = outcome['semantic_results']
        print(f"Semantic results type: {type(semantic_hits)}")
        print(f"Semantic results length: {len(semantic_hits)}")

        if semantic_hits:
            head = semantic_hits[0]
            is_dict = isinstance(head, dict)
            print(f"First result type: {type(head)}")
            print(f"First result keys: {head.keys() if is_dict else 'N/A'}")
            print(f"Has 'concept' key: {'concept' in head if is_dict else False}")

            # Exercise the access pattern used by subsequent cells.
            concepts = [hit['concept'] for hit in semantic_hits]
            print(f"Successfully extracted {len(concepts)} concepts")

    except Exception as exc:
        print(f"Error: {exc}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

## 3. Autoformalisation 

### 3.1 Prompt Construction

In [8]:
from string import Template
import numpy as np
import openai
import os
import json


def _as_import_line(entry):
    """Normalise one import entry to an ``import <module>`` line ("" if blank)."""
    entry = str(entry).strip()
    if not entry:
        return ""
    # Bare module names (no whitespace) get the ``import`` keyword; anything
    # that already looks like a Lean command is kept verbatim. This keeps the
    # caller-supplied imports in the same format as the concept-derived ones
    # (hybrid search results may supply bare module names).
    if entry.startswith("import ") or any(ch.isspace() for ch in entry):
        return entry
    return f"import {entry}"


def _bullet_list(items):
    """Render ``items`` as ``- item`` lines, or "N/A" when there are none."""
    if isinstance(items, str):
        # A lone string would otherwise be iterated character by character.
        items = [items] if items else []
    lines = [f"- {item}" for item in items or ()]
    return "\n".join(lines) if lines else "N/A"


def _format_triples(triples):
    """Render well-formed subject/role/object triples, or "N/A" if none are."""
    lines = [
        f"- {t['subject']} --[{t['role']}]--> {t['object']}"
        for t in triples or ()
        if isinstance(t, dict) and all(k in t for k in ("subject", "role", "object"))
    ]
    # Fall back on the *filtered* result, so a list of malformed triples
    # renders as "N/A" rather than an empty section.
    return "\n".join(lines) if lines else "N/A"


def build_prompt(user_query, matched_concepts, all_imports=None,
                 template_path="./prompts/prompt_template.txt"):
    """Assemble the autoformalisation prompt from retrieved Mathlib concepts.

    Each concept dict is rendered as a text block (name, Lean signature,
    definition, informal description, genus, type, key lemmas, properties and
    semantic triples). The blocks and the collected import lines are then
    substituted into a ``string.Template`` read from ``template_path``.

    Args:
        user_query: Natural-language statement to formalise.
        matched_concepts: Iterable of concept dicts. Missing keys render as
            empty strings / "N/A"; ``None`` is treated as no concepts.
        all_imports: Optional iterable of extra imports. Entries may be bare
            module names or full ``import ...`` lines; bare names are
            prefixed with ``import``. Blank entries are ignored.
        template_path: Path of the prompt template (read as UTF-8). It may
            reference ``$import_block``, ``$user_query``, ``$context`` and
            ``$matched_concepts`` (the same text as ``$context``).

    Returns:
        The filled prompt string.

    Raises:
        OSError: If the template file cannot be read.
        KeyError: If the template references any other ``$name``.
        ValueError: If the template contains an invalid ``$`` placeholder.
    """
    import_lines = {line for line in map(_as_import_line, all_imports or ()) if line}
    concept_text_blocks = []

    for c in matched_concepts or ():
        # Module information carried by the concept itself also becomes an import.
        module_path = c.get("module_path", "")
        if module_path:
            line = _as_import_line(module_path)
            if line:
                import_lines.add(line)

        block = (
            f"Concept: {c.get('name', '')}\n"
            f"Lean Signature: {c.get('signature', '')}\n"
            f"Definition: {c.get('definition', '')}\n"
            f"Informal: {c.get('informal', '')}\n"
            f"Genus: {c.get('genus', '')}\n"
            f"Type: {c.get('semantic_type', '')}\n"
            f"Key Lemmas:\n{_bullet_list(c.get('Key lemmas', []))}\n"
            f"Properties:\n{_bullet_list(c.get('properties', []))}\n"
            f"Semantic Triples:\n{_format_triples(c.get('triples', []))}"
        )
        concept_text_blocks.append(block)

    import_block = "\n".join(sorted(import_lines))
    context = "\n\n---\n\n".join(concept_text_blocks)

    with open(template_path, "r", encoding="utf-8") as f:
        prompt_template = Template(f.read())

    return prompt_template.substitute(
        import_block=import_block,
        user_query=user_query,
        context=context,
        matched_concepts=context,
    )


# 3. Call OpenAI API with exception handling
def ask_openai(prompt, model="gpt-4o", temperature=0.2, max_tokens=800,
               system_prompt=("You are an expert in mathematics and Lean 4. "
                              "You will help formalize mathematical statements.")):
    """Send ``prompt`` to the OpenAI chat completions API and return the reply text.

    Args:
        prompt: User message (typically the output of ``build_prompt``).
        model: Chat model name.
        temperature: Sampling temperature.
        max_tokens: Upper bound on generated tokens.
        system_prompt: System message establishing the Lean 4 expert role.

    Returns:
        The assistant message text. Returns "" if the API call raises (the
        error is printed) or if the reply carries no text content.
    """
    try:
        response = openai.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # ``message.content`` is Optional in the SDK and may be None; callers
        # immediately call string methods on the result, so normalise to "".
        return response.choices[0].message.content or ""
    except Exception as e:
        print(f"OpenAI API call failed: {e}")
        return ""

# 4. Main process
def rag_autoformalize(user_query, top_k=5, all_imports=None):
    """Formalise ``user_query`` using plain semantic retrieval and one LLM call.

    Retrieves ``top_k`` concepts with ``search_similar_concepts``, builds the
    prompt with ``build_prompt`` and returns the raw ``ask_openai`` reply,
    which may still be wrapped in markdown fences. A later cell redefines this
    function to take a ``search_system`` argument.
    """
    concepts = search_similar_concepts(user_query, top_k)
    return ask_openai(build_prompt(user_query, concepts, all_imports=all_imports))

### 3.2 Preparation for Lean Code Generation

In [9]:
import subprocess
import os
import tempfile
from pathlib import Path

def clean_output(output: str) -> str:
    """Extract Lean source from an LLM reply by stripping markdown code fences.

    The first line-leading code fence (```` ```lean ````, ```` ```lean4 ```` or a
    bare ```` ``` ````) marks the start of the code. The opening fence line is
    dropped, and everything from the last ```` ``` ```` onward is dropped too.
    This also handles prose before the fence, prose after the closing fence,
    a closing fence glued to the last code line, and a missing closing fence.
    Without a fence, the reply is returned with surrounding whitespace
    stripped. ``None`` or an empty reply yields "".
    """
    if not output:
        return ""
    text = output.strip()

    # Locate an opening fence at the start of the text or of any line.
    if text.startswith("```"):
        start = 0
    else:
        idx = text.find("\n```")
        if idx == -1:
            return text
        start = idx + 1

    newline = text.find("\n", start)
    if newline == -1:
        # Only an opening fence line and no code after it.
        return ""

    body = text[newline + 1:]
    # Use the *last* fence as the closing one so code fences nested in Lean
    # doc comments inside the block are preserved.
    end = body.rfind("```")
    if end != -1:
        body = body[:end]
    return body.strip()


def find_lean_project_root(start_path: str = ".") -> str:
    """Return the nearest directory at or above ``start_path`` holding a lakefile.

    A directory counts as a Lean project root if it contains either
    ``lakefile.lean`` or ``lakefile.toml`` (Lake accepts both formats). The
    search starts at ``start_path`` itself and walks up to and including the
    filesystem root.

    Returns:
        The resolved project root as a string, or ``os.getcwd()`` if no
        lakefile is found.
    """
    start = Path(start_path).resolve()
    for candidate in (start, *start.parents):
        if any((candidate / name).exists() for name in ("lakefile.lean", "lakefile.toml")):
            return str(candidate)

    # If not found, fall back to the current working directory.
    return os.getcwd()


def validate_lean_output(content: str = None, file_path: str = "auto_output.lean",
                         project_dir: str = None, timeout: float = 30) -> bool:
    """Compile a Lean 4 file and report whether it checks successfully.

    If ``content`` is given, it is first written (UTF-8) to
    ``project_dir/file_path``. The file is then checked with ``lake env lean``,
    then bare ``lean``, then ``lake build <target>``, each run with
    ``cwd=project_dir``.
    - A method whose executable is missing, that times out, or that raises is
      skipped.
    - A method that fails because the ``Mathlib`` module prefix is unknown
      falls through to the next method.
    - Any other non-zero exit returns ``False`` immediately.

    Args:
        content: Lean source to write before checking. ``None`` checks the
            existing file as-is.
        file_path: File path relative to ``project_dir``.
        project_dir: Lean project root. Defaults to ``find_lean_project_root()``.
        timeout: Per-command timeout in seconds. Raise it if large imports
            time out.

    Returns:
        ``True`` if one of the methods exits with code 0, otherwise ``False``.
    """
    if project_dir is None:
        project_dir = find_lean_project_root()

    if content is not None:
        full_path = os.path.join(project_dir, file_path)
        with open(full_path, "w", encoding="utf-8") as f:
            f.write(content)

    # ``lake build`` takes a target name rather than a file: drop the extension.
    build_target = os.path.splitext(file_path)[0]
    validation_methods = [
        ("lake env lean", ["lake", "env", "lean", file_path]),
        ("lean", ["lean", file_path]),
        ("lake build specific", ["lake", "build", build_target]),
    ]

    for method_name, command in validation_methods:
        try:
            print(f"Trying validation method: {method_name}")
            result = subprocess.run(
                command,
                capture_output=True,
                text=True,
                timeout=timeout,
                cwd=project_dir,
            )
        except subprocess.TimeoutExpired:
            print(f"{method_name} timed out, trying next method...")
            continue
        except FileNotFoundError:
            print(f"Command not found: {command[0]}, trying next method...")
            continue
        except Exception as e:
            print(f"Exception with {method_name}: {e}")
            continue

        if result.returncode == 0:
            print(f"Lean 4 validation passed using {method_name}")
            return True

        print(f"{method_name} failed:")
        if result.stdout.strip():
            print(f"STDOUT: {result.stdout.strip()}")
        if result.stderr.strip():
            print(f"STDERR: {result.stderr.strip()}")

        # Lean may print file diagnostics on stdout rather than stderr, so look
        # for the missing-Mathlib error in both streams before giving up.
        if "unknown module prefix 'Mathlib'" in f"{result.stdout}\n{result.stderr}":
            print("Mathlib not found, trying next method...")
            continue
        # Any other failure is treated as an error in the Lean code itself.
        return False

    print("All validation methods failed")
    return False


def setup_lean_environment(project_dir: str):
    """Prepare Lake dependencies for ``project_dir``.

    Runs ``lake exe cache get`` first (300 s limit). Only if that exits
    non-zero does it fall back to ``lake update`` (600 s limit). Any exception,
    e.g. ``lake`` missing or a timeout, is printed and reported as failure.

    Returns:
        ``True`` when either step succeeds, ``False`` otherwise.
    """
    def _lake(*args, limit):
        # All Lake invocations share the same capture/cwd settings.
        return subprocess.run(
            ["lake", *args],
            capture_output=True,
            text=True,
            timeout=limit,
            cwd=project_dir,
        )

    try:
        print("Setting up Lean environment...")

        if _lake("exe", "cache", "get", limit=300).returncode == 0:
            print("Lean environment setup complete")
            return True

        print("Cache get failed, trying lake update...")
        update = _lake("update", limit=600)
        if update.returncode != 0:
            print(f"Lake update failed: {update.stderr}")
            return False

        print("Lean environment setup complete")
        return True

    except Exception as err:
        print(f"Failed to setup Lean environment: {err}")
        return False


def iterative_refine(user_query, max_iters=20, project_dir=None):
    """Generate Lean 4 code for ``user_query`` and retry until it compiles.

    Uses plain semantic retrieval (``search_similar_concepts``, top 5). Each
    iteration asks the LLM, strips markdown fences with ``clean_output`` and
    validates the result in ``project_dir``. Once an invalid attempt exists,
    it is appended to the prompt so the model can correct it.

    Args:
        user_query: Natural-language statement to formalise.
        max_iters: Maximum number of LLM attempts.
        project_dir: Lean project directory. Defaults to
            ``find_lean_project_root()``.

    Returns:
        The first output that validates. Otherwise the last non-empty
        attempt, or "" if every attempt came back empty.
    """
    if project_dir is None:
        project_dir = find_lean_project_root()

    if not setup_lean_environment(project_dir):
        print("Warning: Lean environment setup failed, continuing anyway...")

    # Retrieval depends only on the query (assuming the search is
    # deterministic), so build the base prompt once rather than re-running
    # the embedding search on every iteration.
    base_prompt = build_prompt(user_query, search_similar_concepts(user_query, top_k=5))
    last_output = None

    for i in range(max_iters):
        print(f"\n--- Iteration {i+1} ---")

        prompt = base_prompt
        if last_output:
            # Feed the rejected attempt back so the model can repair it.
            prompt += f"\n\n# The previous output had validation errors. Please fix:\n{last_output}\n\n# Provide a corrected version:"

        cleaned = clean_output(ask_openai(prompt) or "")
        if not cleaned:
            # ask_openai returns "" on API failure, and an empty Lean file
            # compiles without errors, so it must not count as a valid result.
            print(f"Iteration {i+1}: Empty output, retrying...")
            continue

        if validate_lean_output(cleaned, "auto_output.lean", project_dir):
            print(f"Iteration {i+1}: Valid output found!")
            return cleaned

        print(f"Iteration {i+1}: Output invalid, retrying...")
        last_output = cleaned

    print("Max iterations reached, returning last output.")
    return last_output or ""

### 3.3 Hybrid RAG Autoformalisation

In [14]:
# Modify the rag_autoformalize function to accept search_system parameter
def rag_autoformalize(user_query, search_system, top_k=5, all_imports=None):
    """Formalise ``user_query`` with hybrid retrieval and a single LLM call.

    ``search_system.hybrid_search`` supplies both the related concepts (the
    ``'concept'`` entry of each ``'semantic_results'`` item) and, unless the
    caller passes ``all_imports`` explicitly, the import modules. Returns the
    raw ``ask_openai`` reply.
    """
    hits = search_system.hybrid_search(user_query, semantic_top_k=top_k)
    concepts = [hit['concept'] for hit in hits['semantic_results']]
    imports = hits['import_modules'] if all_imports is None else all_imports
    return ask_openai(build_prompt(user_query, concepts, all_imports=imports))


def main():
    """Run one hybrid-RAG autoformalisation of a sample query and validate it.

    Builds a ``HybridSemanticSearch`` over the merged concept JSON and the
    Poincaré MuRP embeddings, then formalises a sample matrix statement with
    ``rag_autoformalize``. The reply is stripped of markdown fences, written
    to ``auto_output.lean`` in the working directory, printed, and checked
    with ``validate_lean_output``. Errors are printed with a traceback rather
    than raised.
    """
    # Configure paths
    json_path = "./Informalisation_and_Mathematical_DSRL/merged_with_embeddings_and_triples.json"
    murp_embedding_path = "./Mathlib4_embeddings/outputs_cleaned/embeddings/model_dict_poincare_300_my_dataset_cleaned"

    try:
        # Initialize hybrid search system
        print("Initializing Hybrid Search System...")
        search_system = HybridSemanticSearch(
            json_path=json_path,
            murp_embedding_path=murp_embedding_path,
            semantic_model="sentence-transformers/sentence-t5-large",
            murp_model_type="poincare"
        )

        query = "For any real matrix A: Matrix m × n, if the columns of A are pairwise orthogonal, then the matrix Aᵀ * A is a diagonal matrix."

        # Use hybrid search system to get autoformalization results
        output = rag_autoformalize(query, search_system, top_k=5)

        # Reuse the shared fence-stripping helper instead of a divergent inline
        # copy; ``or ""`` guards against a None reply.
        output = clean_output(output or "")

        # Write to file
        with open("auto_output.lean", "w", encoding="utf-8") as f:
            f.write(output)

        print(output)

        if output:
            result = validate_lean_output(output)
            print(f"Validation result: {result}")
        else:
            # An empty Lean file compiles cleanly, so "True" would be misleading.
            print("No Lean code was generated; skipping validation.")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()

### 3.4 RAG-Based Autoformalisation with Iterative Refinement

In [None]:
def iterative_refine(user_query, search_system, max_iters=20, project_dir=None):
    """Generate Lean 4 code with hybrid RAG and retry until it compiles.

    Args:
        user_query: Natural-language statement to formalise.
        search_system: Object exposing ``hybrid_search(query, semantic_top_k=...)``
            that returns a dict with ``'semantic_results'`` (items carrying a
            ``'concept'`` entry) and ``'import_modules'``.
        max_iters: Maximum number of LLM attempts.
        project_dir: Lean project directory. Defaults to
            ``find_lean_project_root()``.

    Returns:
        The first output that validates. Otherwise the last non-empty
        attempt, or "" if every attempt came back empty.
    """
    if project_dir is None:
        project_dir = find_lean_project_root()

    if not setup_lean_environment(project_dir):
        print("Warning: Lean environment setup failed, continuing anyway...")

    # Retrieval depends only on the query (assuming hybrid_search is
    # deterministic), so run it once instead of on every iteration.
    hybrid_results = search_system.hybrid_search(user_query, semantic_top_k=5)
    related_concepts = [result['concept'] for result in hybrid_results['semantic_results']]
    base_prompt = build_prompt(
        user_query, related_concepts, all_imports=hybrid_results['import_modules']
    )

    last_output = None

    for i in range(max_iters):
        print(f"\n--- Iteration {i+1} ---")

        prompt = base_prompt
        if last_output:
            # Feed the rejected attempt back so the model can repair it.
            prompt += f"\n\n# The previous output had validation errors. Please fix:\n{last_output}\n\n# Provide a corrected version:"

        cleaned = clean_output(ask_openai(prompt) or "")
        if not cleaned:
            # ask_openai returns "" on API failure, and an empty Lean file
            # compiles without errors, so it must not count as a valid result.
            print(f"Iteration {i+1}: Empty output, retrying...")
            continue

        if validate_lean_output(cleaned, "auto_output.lean", project_dir):
            print(f"Iteration {i+1}: Valid output found!")
            return cleaned

        print(f"Iteration {i+1}: Output invalid, retrying...")
        last_output = cleaned

    print("⚠️ Max iterations reached, returning last output.")
    return last_output or ""


if __name__ == "__main__":
    # Configure paths
    json_path = "./Informalisation_and_Mathematical_DSRL/merged_with_embeddings_and_triples.json"
    murp_embedding_path = "./Mathlib4_embeddings/outputs_cleaned/embeddings/model_dict_poincare_300_my_dataset_cleaned"

    try:
        # Initialize hybrid search system
        print("Initializing Hybrid Search System...")
        search_system = HybridSemanticSearch(
            json_path=json_path,
            murp_embedding_path=murp_embedding_path,
            semantic_model="sentence-transformers/sentence-t5-large",
            murp_model_type="poincare"
        )

        # Statement to formalise: a real matrix with pairwise orthogonal
        # columns has a diagonal Aᵀ * A.
        matrix_statement = (
            "For any real matrix A: Matrix m × n, if the columns of A are pairwise orthogonal, then the matrix Aᵀ * A is a diagonal matrix."
            # Alternative test query:
            # "For any two integers a and b, if a divides b and b divides a, then a is equal to b or a is equal to -b."
        )

        # Call iterative refinement, passing hybrid search system instance
        final_code = iterative_refine(
            user_query=matrix_statement,
            search_system=search_system,
            max_iters=5
        )

        if final_code:
            # Re-check the returned code: iterative_refine hands back its last
            # invalid attempt when it runs out of iterations.
            result = validate_lean_output(final_code)
            print(f"Validation result: {result}")
        else:
            # An empty Lean file compiles cleanly, so validating it would
            # falsely report success.
            result = False
            print("No Lean code was generated; skipping validation.")

        label = "Final validated code" if result else "Final code (did not validate)"
        print(f"\n{label}:\n{final_code}")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()