# In [None]:
import os
import json
import re
import time
import random
from collections import defaultdict
from openai import OpenAI  # Using the new OpenAI client

# Pattern for a valid cleaned definiendum: lowercase ASCII letters, digits and
# underscores only, at least 2 characters long. Uppercase letters never match,
# so apply it to already-lowercased text (clean_definiendum lowercases its output).
valid_definiendum_pattern = re.compile(r"^[a-z0-9_]{2,}$")
# Mathematical symbols stripped out of raw definienda.
math_symbols = set("+-*/=∑∏√∫<>∈∉{}[]()")

def clean_definiendum(definiendum):
    """Normalise a raw definiendum into a lowercase identifier-like token.

    Periods and spaces become underscores. Every other character outside
    ASCII ``[A-Za-z0-9_]`` (mathematical symbols, non-ASCII letters,
    punctuation, tabs, newlines, ...) is dropped, and the result is lowercased.

    Args:
        definiendum: Raw name. Falsy or non-string values yield ``""``.

    Returns:
        The cleaned string, or ``""`` if nothing survives. The result is not
        checked against ``valid_definiendum_pattern``; for example, a single
        character is returned as-is.
    """
    if not definiendum or not isinstance(definiendum, str):
        return ""

    dotted_to_underscore = definiendum.replace(".", "_")
    # Drop maths symbols and anything non-ASCII in a single pass.
    ascii_only = "".join(
        ch for ch in dotted_to_underscore
        if ch not in math_symbols and ord(ch) < 128
    )
    spaced_to_underscore = ascii_only.replace(" ", "_")
    identifier_chars = re.sub(r'[^A-Za-z0-9_]', '', spaced_to_underscore)
    return identifier_chars.lower()

def normalize_entity_name(name):
    """Lowercase *name* and join its whitespace-separated words with underscores.

    Leading and trailing whitespace is discarded, and each run of inner
    whitespace becomes a single ``_`` (e.g. ``"  Foo  Bar "`` -> ``"foo_bar"``).
    The result is used as the lookup key when matching entities against the
    dataset.
    """
    words = name.lower().strip().split()
    return "_".join(words)

class GPTBaseline:
    """GPT retrieval baseline: directly retrieves relevant Mathlib4 modules using GPT.

    The model is asked for a JSON array of module names. If the final
    attempt's reply cannot be parsed as such, dotted module-like names are
    scraped from the raw text instead.
    """

    # Maximum number of names returned by the regex fallback.
    _FALLBACK_LIMIT = 5
    # Dotted identifiers such as ``Mathlib.Data.Real.Basic``. The repeated group
    # is non-capturing so ``re.findall`` returns whole matches (with a capturing
    # group it would return only the last ``.Segment``). At least one dot is
    # required so ordinary English words in the reply are not picked up.
    _MODULE_NAME_RE = re.compile(r"[A-Za-z][A-Za-z0-9_]*(?:\.[A-Za-z][A-Za-z0-9_]*)+")

    def __init__(self, api_key=None, model="gpt-4o", verbose=True):
        """Create the OpenAI client.

        Args:
            api_key: Explicit OpenAI API key. If falsy, the ``OPENAI_API_KEY``
                environment variable is used.
            model: Chat-completions model name.
            verbose: When False, only ``"error"``-level messages are printed.

        Raises:
            ValueError: If no key is given and ``OPENAI_API_KEY`` is unset or empty.
        """
        self.model = model
        self.verbose = verbose

        key = api_key or os.environ.get("OPENAI_API_KEY")
        if not key:
            raise ValueError("OpenAI API key not provided and OPENAI_API_KEY environment variable not set")
        self.client = OpenAI(api_key=key)

    def log(self, message, level="info"):
        """Log messages according to verbosity settings."""
        if self.verbose or level == "error":
            print(message)

    @staticmethod
    def _coerce_module_list(parsed):
        """Return the module names from a parsed JSON value.

        Keeps only string entries (stripped, empty ones dropped), de-duplicated
        in first-seen order so the model's ranking is preserved.

        Raises:
            ValueError: If ``parsed`` is not a JSON array.
        """
        if not isinstance(parsed, list):
            raise ValueError(f"expected a JSON array, got {type(parsed).__name__}")
        names = (item.strip() for item in parsed if isinstance(item, str))
        return list(dict.fromkeys(name for name in names if name))

    @classmethod
    def _scrape_module_names(cls, content):
        """Fallback: extract up to ``_FALLBACK_LIMIT`` dotted names from free text, in order of appearance."""
        found = cls._MODULE_NAME_RE.findall(content)
        return list(dict.fromkeys(found))[:cls._FALLBACK_LIMIT]

    def search_modules(self, query, max_retries=3):
        """Retrieve relevant Mathlib4 modules for a query using GPT.

        Args:
            query: Informal description of the mathematical concept.
            max_retries: Number of API calls to attempt before giving up.

        Returns:
            A list of module-name strings in the model's ranked order, with
            duplicates removed. If the last reply is not a parseable JSON array,
            the result holds up to ``_FALLBACK_LIMIT`` dotted names scraped from
            the text. If every API call fails (or ``max_retries <= 0``), the
            result is ``[]``.
        """
        self.log(f"GPT searching for: '{query[:50]}...'")
        prompt = f"""
        You are an expert in the Lean 4 Theorem Prover and Mathlib4 library. 
        Identify the 10 most relevant Mathlib4 modules needed to formalize a given mathematical concept in Lean 4.

        Query: {query}

        Return EXACTLY 10 module names in a JSON array format.
        Only include modules that exist in Mathlib4.
        Rank modules by relevance, most essential first.
        Prioritize core definitions, key theorems, algebraic structures, notation, and important constructions.
        """

        for attempt in range(max_retries):
            is_last_attempt = attempt == max_retries - 1

            try:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant specialized in Lean and Mathlib4."},
                        {"role": "user", "content": prompt}
                    ],
                    temperature=0.1,
                    max_tokens=500
                )
                # ``content`` may be None (e.g. a refusal); treat that as an empty reply.
                content = (response.choices[0].message.content or "").strip()
            except Exception as e:  # network/API errors: retry, then give up
                if is_last_attempt:
                    self.log(f"GPT search failed after {max_retries} attempts: {e}", "error")
                    return []
                time.sleep(2)
                continue

            try:
                # Replies often wrap the array in prose or code fences; take the
                # span from the first '[' to the last ']'.
                json_match = re.search(r'\[.*\]', content, re.DOTALL)
                parsed = json.loads(json_match.group() if json_match else content)
                modules = self._coerce_module_list(parsed)
            except ValueError as e:  # includes json.JSONDecodeError
                if is_last_attempt:
                    self.log(f"Failed to parse GPT response after {max_retries} attempts: {e}", "error")
                    return self._scrape_module_names(content)
                time.sleep(1)
                continue

            self.log(f"GPT found {len(modules)} modules: {modules}")
            return modules

        return []

class GPTEvaluator:
    """GPT Baseline evaluator - aligned with original experimental design.

    Reads a JSON file whose top-level keys are module names. Each value holds
    a ``"definitions"`` list and, for :meth:`find_entity_details`, a
    ``"dependencies"`` list. Definitions with both a ``"name"`` and a
    ``semantic_analysis["informal"]`` description are sampled. Each informal
    description is sent to ``GPTBaseline``, and a prediction counts as correct
    when the definition's own module name appears verbatim among the returned
    modules.

    NOTE(review): exact string matching assumes the dataset's module keys use
    the same format GPT returns (e.g. dotted ``Mathlib.X.Y`` names) — confirm.
    """

    def __init__(self, json_path, openai_api_key=None, sample_size=50, random_seed=42):
        """Store configuration and immediately sample the evaluation data.

        Args:
            json_path: Path to the dataset JSON file.
            openai_api_key: Passed to ``GPTBaseline``; ``None`` lets it look up
                the key itself.
            sample_size: Maximum number of definitions to evaluate (clamped to
                the number available).
            random_seed: Seed for the sampling shuffle.

        Raises:
            FileNotFoundError: If ``json_path`` does not exist.
        """
        self.json_path = json_path
        self.openai_api_key = openai_api_key
        self.sample_size = sample_size
        self.random_seed = random_seed
        self.evaluation_data = None

        self.prepare_evaluation_data()

    def _load_dataset(self):
        """Read and return the JSON dataset at ``self.json_path``.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        if not os.path.exists(self.json_path):
            raise FileNotFoundError(f"File not found: {self.json_path}")
        with open(self.json_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def prepare_evaluation_data(self):
        """Sample up to ``self.sample_size`` definitions that have informal descriptions.

        Each sample is ``{"module", "name", "query"}``, where ``query`` is the
        informal description. ``self.sample_size`` is reduced if fewer
        definitions are available.

        Returns:
            The sampled list, also stored in ``self.evaluation_data``.
        """
        print("\nPreparing evaluation data...")
        original_dataset = self._load_dataset()

        all_definitions = []
        for module_name, module_data in original_dataset.items():
            for definition in module_data.get("definitions", []):
                if not isinstance(definition, dict):
                    continue
                semantic_analysis = definition.get("semantic_analysis", {})
                informal = semantic_analysis.get("informal", "") if isinstance(semantic_analysis, dict) else ""
                def_name = definition.get("name", "")
                if informal and def_name:
                    all_definitions.append({
                        "module": module_name,
                        "name": def_name,
                        "query": informal
                    })

        print(f"Found {len(all_definitions)} valid definitions with informal descriptions")

        if len(all_definitions) < self.sample_size:
            print(f"Only {len(all_definitions)} definitions available, using all")
            self.sample_size = len(all_definitions)

        # A private RNG re-seeded on every call makes sampling reproducible
        # across repeated calls and leaves the global ``random`` state untouched.
        rng = random.Random(self.random_seed)
        rng.shuffle(all_definitions)
        self.evaluation_data = all_definitions[:self.sample_size]

        print(f"Prepared {len(self.evaluation_data)} evaluation samples")
        return self.evaluation_data

    def find_entity_details(self, entities):
        """Find the module and dependencies of each entity in the dataset.

        Names are compared after ``normalize_entity_name``. A definition's
        ``"concept_name"`` is preferred over its ``"name"``.

        Args:
            entities: Iterable of entity names.

        Returns:
            ``{entity: {"module": ..., "dependencies": ...}}`` for entities that
            were found. If an entity matches several definitions, the last one
            in file order wins.
        """
        entity_details = {}
        original_dataset = self._load_dataset()

        normalized_entities = {normalize_entity_name(ent): ent for ent in entities}

        for module_name, module_data in original_dataset.items():
            module_dependencies = module_data.get("dependencies", [])
            for definition in module_data.get("definitions", []):
                if not isinstance(definition, dict):
                    continue
                raw_name = definition.get("concept_name") or definition.get("name")
                if not raw_name:
                    continue
                norm_name = normalize_entity_name(raw_name)
                if norm_name in normalized_entities:
                    entity_details[normalized_entities[norm_name]] = {
                        "module": module_name,
                        "dependencies": module_dependencies
                    }

        return entity_details

    def evaluate_gpt_baseline(self, model="gpt-4o"):
        """Query GPT for every sampled definition and compute module recall.

        Args:
            model: Chat model passed to ``GPTBaseline``.

        Returns:
            A results dict with counts, ``module_recall`` (hits / samples),
            ``avg_query_time``, ``module_coverage`` (each module's share of all
            hits) and one ``detailed_results`` entry per sample.
        """
        if self.evaluation_data is None:
            self.prepare_evaluation_data()

        sample_size = len(self.evaluation_data)
        print(f"\nEvaluating GPT Baseline with {sample_size} samples...")
        print("=" * 80)

        gpt_model = GPTBaseline(api_key=self.openai_api_key, model=model, verbose=True)

        results = {
            "model_name": "GPT Baseline",
            "total_samples": sample_size,
            "correct_predictions": 0,
            "incorrect_predictions": 0,
            "module_recall": 0.0,
            "query_times": [],
            "module_coverage": defaultdict(int),
            "detailed_results": []
        }

        progress_interval = max(1, sample_size // 10)

        for i, sample in enumerate(self.evaluation_data):
            if i % progress_interval == 0:
                print(f"Processing sample {i+1}/{sample_size} ({((i+1)/sample_size)*100:.1f}%)")

            # Bound before the query so the error branch always reports this
            # sample's module instead of an unbound or stale value.
            target_module = sample["module"]
            start_time = time.time()
            error = None
            ranked_modules = []
            try:
                # De-duplicate but keep the model's ranking for the saved results.
                ranked_modules = list(dict.fromkeys(gpt_model.search_modules(sample["query"])))
            except Exception as e:  # keep evaluating the remaining samples
                error = e
            elapsed_time = time.time() - start_time
            # Exactly one timing entry per sample keeps avg_query_time consistent.
            results["query_times"].append(elapsed_time)

            if error is not None:
                results["incorrect_predictions"] += 1
                print(f"Error during search for sample {i+1}: {error}")
                results["detailed_results"].append({
                    "sample_id": i,
                    "definition_name": sample["name"],
                    "target_module": target_module,
                    "query": sample["query"],
                    "retrieved_modules": [],
                    "error": str(error),
                    "is_correct": False,
                    "time_taken": elapsed_time
                })
                continue

            is_correct = target_module in ranked_modules
            if is_correct:
                results["correct_predictions"] += 1
                results["module_coverage"][target_module] += 1
            else:
                results["incorrect_predictions"] += 1

            results["detailed_results"].append({
                "sample_id": i,
                "definition_name": sample["name"],
                "target_module": target_module,
                "query": sample["query"],
                "retrieved_modules": ranked_modules,
                "is_correct": is_correct,
                "time_taken": elapsed_time
            })
            # Simple pause between successful API calls.
            time.sleep(1)

        results["module_recall"] = results["correct_predictions"] / sample_size if sample_size > 0 else 0
        results["avg_query_time"] = sum(results["query_times"]) / sample_size if sample_size > 0 else 0

        total_correct = results["correct_predictions"]
        if total_correct > 0:
            # Turn hit counts into each module's share of all correct predictions.
            for module, count in results["module_coverage"].items():
                results["module_coverage"][module] = count / total_correct

        self._print_summary(results)
        return results

    @staticmethod
    def _print_summary(results):
        """Print the headline metrics and the five modules with the highest coverage."""
        print("\n" + "=" * 80)
        print("GPT BASELINE EVALUATION SUMMARY")
        print("=" * 80)
        print(f"Total Samples: {results['total_samples']}")
        print(f"Correct Predictions: {results['correct_predictions']}")
        print(f"Incorrect Predictions: {results['incorrect_predictions']}")
        print(f"Module Recall: {results['module_recall']:.4f}")
        print(f"Average Query Time: {results['avg_query_time']:.2f} seconds")

        if results["module_coverage"]:
            print("\nTop Modules by Coverage:")
            sorted_modules = sorted(results["module_coverage"].items(), key=lambda x: x[1], reverse=True)[:5]
            for module, coverage in sorted_modules:
                print(f"  {module}: {coverage:.2f}")

    def save_evaluation_results(self, results, output_path="gpt_baseline_evaluation.json"):
        """Save evaluation results to JSON.

        Write failures (I/O errors, unserialisable values) are reported and
        not raised, so a finished evaluation is never lost to an exception here.
        """
        print(f"\nSaving evaluation results to {output_path}")
        save_data = {
            "evaluation_summary": {
                "model_name": results["model_name"],
                "total_samples": results["total_samples"],
                "correct_predictions": results["correct_predictions"],
                "incorrect_predictions": results["incorrect_predictions"],
                "module_recall": results["module_recall"],
                "avg_query_time": results["avg_query_time"],
                "module_coverage": dict(results["module_coverage"])
            },
            "detailed_results": results["detailed_results"]
        }

        try:
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(save_data, f, indent=2, ensure_ascii=False)
            print("Evaluation results saved successfully")
        except (OSError, TypeError, ValueError) as e:
            print(f"Failed to save evaluation results: {e}")


# ----------------------
# Example usage
# ----------------------
def main():
    """Run the GPT baseline evaluation on a small sample, then one example query.

    The OpenAI key is read from the ``OPENAI_API_KEY`` environment variable.
    Any exception is printed with its traceback instead of propagating.
    """
    json_path = "./Informalisation_and_Mathematical_DSRL/merged_with_embeddings_and_triples.json"
    # Read the key from the environment rather than hard-coding a placeholder.
    # Any non-empty string passed as api_key is used verbatim (GPTBaseline only
    # consults the environment when api_key is falsy), so a placeholder would
    # make every request fail authentication. None defers to GPTBaseline's own
    # lookup, which raises a clear ValueError if the variable is missing.
    api_key = os.environ.get("OPENAI_API_KEY")

    try:
        print("Initializing GPT Baseline Evaluator...")
        evaluator = GPTEvaluator(json_path, api_key, sample_size=10)

        print("\nStarting GPT Baseline evaluation...")
        gpt_results = evaluator.evaluate_gpt_baseline()

        evaluator.save_evaluation_results(gpt_results)

        print("\n" + "=" * 80)
        print("Running single query example with GPT Baseline...")
        gpt_model = GPTBaseline(api_key=api_key, model="gpt-4o", verbose=True)

        query = "For any real matrix A: Matrix m × n, if the columns of A are pairwise orthogonal, then the matrix Aᵀ * A is a diagonal matrix."
        gpt_modules = gpt_model.search_modules(query)

        print("\nRetrieved modules:")
        for module in gpt_modules:
            print(f" - {module}")

    except Exception as e:  # top-level script boundary: report and exit cleanly
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()


# Run the example evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
