# Mathematical Definition Semantic Extraction Pipeline

In [None]:
import os
import openai

# Prefer the OPENAI_API_KEY environment variable over a key pasted into the
# notebook (keys committed together with notebooks tend to leak). To
# hard-code a key anyway, replace the "" default below.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
if OPENAI_API_KEY:
    openai.api_key = OPENAI_API_KEY
# NOTE(review): when no key is found, openai.api_key is left untouched instead
# of being set to "", so the client's own key lookup still applies — confirm
# against the installed openai version.

In [None]:
import os
import json
import openai
import time
import re
from collections import defaultdict

# --- Configuration Section ---
# Path to the Mathlib source directory inside a local clone of
# https://github.com/leanprover-community/mathlib4.git (clone it first).
# Can be overridden with the MATHLIB_ROOT_DIR environment variable.
mathlib_root_dir = os.environ.get("MATHLIB_ROOT_DIR", "path/to/mathlib4/Mathlib")
# Destination JSON file for the extracted semantic data; can be overridden
# with the INFORMAL_OUTPUT_FILE environment variable.
output_file = os.environ.get(
    "INFORMAL_OUTPUT_FILE", "./informal_data/informal_linear_algebra_Finsupp.json"
)

# Get all modules in LinearAlgebra directory
def get_all_linear_algebra_modules():
    """Collect dotted module names for every ``.lean`` file under ``LinearAlgebra/Finsupp``.

    Despite the function name, only the ``Finsupp`` subtree of
    ``<mathlib_root_dir>/LinearAlgebra`` is walked. Each file path is mapped
    through ``get_module_path`` and kept only if the resulting name starts
    with ``Mathlib.LinearAlgebra.Finsupp``.

    Returns:
        A sorted list of unique module names.
    """
    search_root = os.path.join(mathlib_root_dir, "LinearAlgebra/Finsupp")

    # Lazily enumerate every Lean source file below the search root.
    lean_paths = (
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(search_root)
        for filename in filenames
        if filename.endswith(".lean")
    )

    module_names = {
        name
        for name in map(get_module_path, lean_paths)
        if name.startswith("Mathlib.LinearAlgebra.Finsupp")
    }
    return sorted(module_names)

# Convert module name to file path
def module_to_filepath(module_name):
    """Resolve a dotted Mathlib module name to its ``.lean`` file under ``mathlib_root_dir``.

    A trailing ``.vN`` suffix on ``module_name`` is ignored. If
    ``<Module>.lean`` does not exist, ``<Module>.v1.lean`` is tried next.

    Args:
        module_name: Dotted name such as ``Mathlib.LinearAlgebra.Basis``.

    Returns:
        The path of the first candidate file that exists, or None (after
        printing diagnostics) when neither exists.
    """
    # Drop an optional version suffix such as ".v2".
    clean_name = re.sub(r'\.v\d+$', '', module_name)

    # Strip only the leading package prefix; str.replace would also remove
    # any later "Mathlib." occurrence inside the name.
    prefix = "Mathlib."
    if clean_name.startswith(prefix):
        clean_name = clean_name[len(prefix):]
    relative_path = clean_name.replace(".", os.sep) + ".lean"
    full_path = os.path.join(mathlib_root_dir, relative_path)

    if os.path.exists(full_path):
        return full_path

    # Rewrite only the final extension: replacing every ".lean" in the path
    # would also corrupt directory names that contain ".lean".
    versioned_path = full_path[:-len(".lean")] + ".v1.lean"
    if os.path.exists(versioned_path):
        return versioned_path

    print(f"⚠️ File does not exist: {full_path}")
    print(f"Current working directory: {os.getcwd()}")
    print(f"Mathlib root directory: {os.path.abspath(mathlib_root_dir)}")
    return None

# Get module name from file path
def get_module_path(file_path):
    """Map a ``.lean`` file path to its dotted Mathlib module name.

    The path is made relative to ``mathlib_root_dir``. Its extension and any
    trailing ``.vN`` version suffix are removed, path separators become dots,
    and ``Mathlib.`` is prepended. If ``os.path.relpath`` raises ValueError
    (e.g. paths on different drives), ``file_path`` is returned unchanged.
    """
    try:
        rel = os.path.relpath(file_path, start=mathlib_root_dir)
    except ValueError:
        return file_path

    stem, _extension = os.path.splitext(rel)
    unversioned = re.sub(r'\.v\d+$', '', stem)
    return "Mathlib." + ".".join(unversioned.split(os.sep))

# Extract file content
def extract_file_content(filepath):
    """Return the UTF-8 text of ``filepath``.

    Returns "" when ``filepath`` is falsy or does not exist. Read failures
    are printed and also yield "" instead of raising.
    """
    if filepath and os.path.exists(filepath):
        try:
            with open(filepath, "r", encoding="utf-8") as handle:
                text = handle.read()
        except Exception as err:
            print(f"Failed to read file: {filepath} - {err}")
            text = ""
        return text
    return ""

# Extract module docstring
def extract_module_docstring(content):
    """Return the text of the first Lean module docstring (``/-! ... -/``) in ``content``.

    Leading and trailing runs of ``/``, ``*``, ``-``, spaces and tabs are
    stripped from every line of the docstring. Line structure, including
    blank lines between paragraphs, is kept.

    Returns:
        The cleaned docstring, or "" when ``content`` is empty or has no
        module docstring.
    """
    if not content:
        return ""
    try:
        # Lean closes the docstring with exactly "-/"; no character follows it.
        match = re.search(r'/-!(.*?)-/', content, re.DOTALL)
        if match:
            doc = match.group(1).strip()
            # [ \t] instead of \s: \s also matches "\n", which let these
            # line-anchored substitutions swallow blank lines.
            doc = re.sub(r'^[/* \t-]+', '', doc, flags=re.MULTILINE)
            doc = re.sub(r'[/* \t-]+$', '', doc, flags=re.MULTILINE)
            return doc
        return ""
    except Exception as e:
        print(f"Failed to extract documentation: {e}")
        return ""

# Extract imports
def extract_imports(content):
    """Return the module names imported by Lean source ``content``.

    Only ``import`` commands at the start of a line (optionally indented)
    are recognised, so the word "import" inside docstrings, comments or
    commented-out lines (``-- import X``) is ignored.

    Returns:
        De-duplicated module names in first-seen order ([] for empty input
        or on failure).
    """
    if not content:
        return []
    try:
        # A dict keeps insertion order, so the de-duplicated result is
        # deterministic (a set of str iterates in a per-run hash order).
        imports = {}
        for match in re.finditer(r'^[ \t]*import[ \t]+([\w.]+)', content, re.MULTILINE):
            imp = match.group(1).strip()
            if imp:
                imports[imp] = None
        return list(imports)
    except Exception as e:
        print(f"⚠️ Failed to extract imports: {e}")
        return []

# Extract definitions from content
def extract_definitions_from_content(content):
    """Split Lean source into one text chunk per declaration.

    A chunk starts at a line that begins with one of the keywords def,
    class, structure, abbrev, notation, lemma or theorem followed by a
    name. The keyword may be preceded by ``@[...]`` attributes (on the same
    or previous lines) and by modifiers such as ``noncomputable``,
    ``protected`` or ``private``. A chunk extends up to the next such
    declaration start at column 0, or to the end of ``content``.

    Returns:
        The stripped chunks in source order ([] for empty input).
    """
    if not content:
        return []
    keywords = r'(?:def|class|structure|abbrev|notation|lemma|theorem)'
    # Attributes/modifiers must belong to the *start* pattern; otherwise e.g.
    # "noncomputable def foo" is not seen as a new declaration and is glued
    # onto the end of the previous chunk.
    modifiers = r'(?:@\[[^\]]*\]\s*|(?:noncomputable|protected|private|partial|unsafe|nonrec)\s+)*'
    decl_start = modifiers + keywords + r'\s+\w+'
    pattern = rf'(^\s*{decl_start}[\s\S]*?)(?=\s*^{decl_start}|\Z)'
    return [m.group(1).strip() for m in re.finditer(pattern, content, re.MULTILINE | re.DOTALL)]

# ================================================================
# Semantic Analysis Agent (extracts basic semantic info)
# ================================================================
# The prompt is joined to the input by plain concatenation (not str.format),
# so its braces are literal: the example below must be single-braced, valid
# JSON, exactly as the model should reproduce it.
BASIC_SEMANTIC_ANALYSIS_PROMPT = """
You are a mathematical definition semantic analysis expert. Process definitions from Lean files and output in strict JSON format. Extract all definitions and concepts from a Lean file:

### Output Requirements (Strict JSON Format)
{
  "definitions": [
    {
      "name": "Definition name",
      "type": "Definition type (structure/class/def etc.)",
      "signature": "Complete definition content in Lean4 code extracted from lean files",
      "body": "Complete definition content",
      "semantic_analysis": {
        "informal": "Concise natural language explanation of the core concepts",
        "concepts": [
          {
            "name": "Mathematical Concepts and Terminologies name",
            "informal_definition": "Generate a scientifically rigorous and precise natural language description of the core concepts",
            "signature": "Complete definition content in Lean4 code extracted from lean files"
          }
        ]
      }
    }
  ]
}

### Instructions
1. Focus exclusively on the formal definition content
2. Generate concise, non-redundant explanations
3. Omit all theorem proofs and implementation details
4. Output must be valid JSON
"""

# =====================================
# API Call Functions
# =====================================
def call_gpt_api(prompt, max_retries=3, model="gpt-4o", max_tokens=2048):
    """Send ``prompt`` to the chat-completions API and return the reply text.

    Requests a JSON-object response at temperature 0.2. Rate-limit errors
    back off linearly (5s, 10s, ...). Any other error waits 5s before
    retrying. No sleep happens after the final attempt.

    Args:
        prompt: User message content.
        max_retries: Total number of attempts.
        model: Chat model name.
        max_tokens: Completion token limit. A reply cut off at this limit
            is returned as-is, with a warning, and may be incomplete JSON.

    Returns:
        The stripped reply text, or "" if every attempt failed or the reply
        had no text content.
    """
    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            response = openai.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": "You are an expert in mathematical semantic analysis"},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.2,
                max_tokens=max_tokens,
                response_format={"type": "json_object"},
            )
        except openai.RateLimitError:
            if is_last_attempt:
                print(f"Rate limit hit. Giving up after {max_retries} attempts.")
                break
            delay = (attempt + 1) * 5
            print(f"Rate limit hit. Retrying in {delay}s... ({attempt+1}/{max_retries})")
            time.sleep(delay)
            continue
        except Exception as e:
            print(f"API Error: {str(e)}")
            if is_last_attempt:
                break
            time.sleep(5)
            continue

        choice = response.choices[0]
        if choice.finish_reason == "length":
            print("⚠️ Response truncated at max_tokens; JSON may be incomplete.")
        # content can be None (e.g. no text was produced); normalise to "".
        return (choice.message.content or "").strip()
    return ""

def generate_basic_semantic_analysis(definition):
    """Ask the semantic-analysis agent about one Lean definition chunk.

    The chunk is appended to BASIC_SEMANTIC_ANALYSIS_PROMPT under an
    "Input Lean Definition Content" heading. The raw reply text from
    ``call_gpt_api`` is returned unparsed.
    """
    input_section = f"\n### Input Lean Definition Content\n{definition}"
    full_prompt = "".join((BASIC_SEMANTIC_ANALYSIS_PROMPT, input_section))
    return call_gpt_api(full_prompt)

# =====================================
# Utility Functions
# =====================================
def extract_json_from_response(response_text):
    """Parse a model reply as JSON, tolerating text around a JSON object.

    The whole reply is parsed first. If that fails, the span from the first
    "{" to the last "}" is parsed instead. Failures are printed.

    Returns:
        The parsed value, or None when the reply is empty or neither
        attempt succeeds.
    """
    if not response_text:
        return None

    try:
        return json.loads(response_text)
    except json.JSONDecodeError as err:
        print(f"Direct parsing failed: {err}")

    # Fallback: isolate the outermost {...} span (e.g. inside ``` fences).
    try:
        first_brace = response_text.find('{')
        last_brace = response_text.rfind('}')
        if -1 < first_brace < last_brace:
            return json.loads(response_text[first_brace:last_brace + 1])
    except Exception as err:
        print(f"Fragment extraction failed: {err}")
    return None

# =====================================
# Main Processing Pipeline
# =====================================
def _clean_definition_entry(def_entry):
    """Normalise one agent-returned definition entry for the output file.

    Returns:
        ``(clean_definition, concept_count)``, or None when ``def_entry`` is
        not a JSON object. A missing or non-object ``semantic_analysis``
        becomes {}. A missing or non-list ``concepts`` counts as zero.
    """
    if not isinstance(def_entry, dict):
        return None
    semantic = def_entry.get("semantic_analysis")
    if not isinstance(semantic, dict):
        # The model may emit null (or another non-object) for this field.
        semantic = {}
    concepts = semantic.get("concepts")
    concept_count = len(concepts) if isinstance(concepts, list) else 0
    clean_definition = {
        "name": def_entry.get("name", ""),
        "type": def_entry.get("type", ""),
        "signature": def_entry.get("signature", ""),
        "body": def_entry.get("body", ""),
        "semantic_analysis": semantic,
    }
    return clean_definition, concept_count


def main():
    """Run the semantic-extraction pipeline and write grouped results to ``output_file``.

    For each module from ``get_all_linear_algebra_modules()``, this:

    1. reads its source file;
    2. extracts the module docstring, imports and declaration chunks;
    3. sends each chunk to the semantic-analysis agent;
    4. groups the cleaned definition entries under the module name.

    The output directory is created before any API call, so a missing
    directory cannot discard all results at the final write.
    """
    # Dynamically get all target modules
    target_modules = get_all_linear_algebra_modules()

    if not target_modules:
        print("❌ No modules found under Mathlib.LinearAlgebra")
        return

    print(f"🔍 Found {len(target_modules)} LinearAlgebra modules")

    # Fail fast, before spending API calls, if the output location is unusable.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    grouped_results = defaultdict(lambda: {
        "docstring": "",
        "dependencies": [],
        "definitions": []
    })
    total_def_count = 0
    total_concept_count = 0

    print("Starting processing of LinearAlgebra modules...\n")

    for module_index, module_name in enumerate(target_modules):
        path = module_to_filepath(module_name)
        if not path or not os.path.isfile(path):
            print(f"File does not exist: {module_name}")
            continue

        module_path = get_module_path(path)
        file_content = extract_file_content(path)
        if not file_content:
            print(f"File content is empty: {path}")
            continue

        print(f"\n{'='*50}")
        print(f"Processing module ({module_index+1}/{len(target_modules)}): {module_name}")
        print(f"Path: {path}")
        print(f"{'='*50}")

        module_doc = extract_module_docstring(file_content)
        dependencies = extract_imports(file_content)
        definitions = extract_definitions_from_content(file_content)

        print(f"Found {len(definitions)} definitions")

        for def_index, def_body in enumerate(definitions):
            print(f"\n[Definition {def_index+1}/{len(definitions)}] Processing...")

            # Semantic analysis
            print("  Invoking Semantic Analysis Agent...")
            basic_response = generate_basic_semantic_analysis(def_body)
            basic_data = extract_json_from_response(basic_response)

            entries = basic_data.get("definitions") if isinstance(basic_data, dict) else None
            if not isinstance(entries, list):
                print("  ⚠️ Agent did not return valid data. Skipping definition.")
            else:
                for def_entry in entries:
                    cleaned = _clean_definition_entry(def_entry)
                    if cleaned is None:
                        continue
                    clean_definition, concept_count = cleaned
                    print(f"  Found {concept_count} concepts")
                    total_concept_count += concept_count
                    grouped_results[module_path]["definitions"].append(clean_definition)
                    total_def_count += 1

            # Rate control between API calls; applied after failed calls too.
            time.sleep(1.5)

        # Add module metadata
        grouped_results[module_path]["docstring"] = module_doc
        grouped_results[module_path]["dependencies"] = dependencies
        print(f"✓ Completed module: {module_name}")

    # Write final JSON file
    print(f"\nWriting JSON file: {output_file}")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(grouped_results, f, indent=2, ensure_ascii=False)

    print(f"\n{'='*50}")
    print("Processing Complete!")
    print(f"  Modules processed: {len(grouped_results)}")
    print(f"  Definitions extracted: {total_def_count}")
    print(f"  Concepts analyzed: {total_concept_count}")
    print(f"  Output file: {os.path.abspath(output_file)}")
    print(f"{'='*50}")

if __name__ == "__main__":
    main()

## Counting extracted definitions and concepts

In [None]:
import os
import glob
import json

def count_definitions_concepts(json_path):
    """Count definitions and concepts in one pipeline output JSON file.

    The file maps module names to objects with a ``definitions`` list. Each
    definition may carry ``semantic_analysis.concepts``. A ``null`` value
    for ``definitions``, ``semantic_analysis`` or ``concepts`` counts as
    empty; such values may appear when a model reply was saved as-is.

    Returns:
        ``(definitions_count, concepts_count)`` for the file.
    """
    definitions_count = 0
    concepts_count = 0

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for module_content in data.values():
        definitions = module_content.get("definitions") or []
        definitions_count += len(definitions)

        for definition in definitions:
            semantic = definition.get("semantic_analysis") or {}
            concepts = semantic.get("concepts") or []
            concepts_count += len(concepts)

    return definitions_count, concepts_count

if __name__ == "__main__":
    # Sum definition/concept counts over every JSON output file in the folder.
    json_folder = "./informal_data/"
    json_files = glob.glob(os.path.join(json_folder, "*.json"))

    per_file_counts = [count_definitions_concepts(path) for path in json_files]
    total_definitions = sum(defs for defs, _cons in per_file_counts)
    total_concepts = sum(cons for _defs, cons in per_file_counts)

    print(f"Total definitions in all JSON files in the folder: {total_definitions}")
    print(f"Total concepts in all JSON files in the folder: {total_concepts}")