In [1]:
%pip install --upgrade --quiet openai

Note: you may need to restart the kernel to use updated packages.


In [2]:
import nest_asyncio
nest_asyncio.apply()

In [3]:
import os
from PIL import Image
from IPython.display import HTML, Image, Markdown, display

In [4]:
import json
import time
import re
from typing import Dict, List, Any, Set, Optional
import PIL.Image
from pydantic import BaseModel, ValidationError
import asyncio
import aiohttp
import base64
from openai import OpenAI


In [None]:
# Do not hard-code the API key in the notebook and do not clobber a key that is
# already configured in the environment: only ask for it when it is missing.
if not os.environ.get("OPENAI_API_KEY"):
    from getpass import getpass  # local import: only needed when the key is absent

    # getpass hides the typed value so the secret never appears in cell output.
    os.environ["OPENAI_API_KEY"] = getpass("Enter your OpenAI API key: ")

In [6]:
# Identifier of the model to use. The rejection-improvement pipeline cell further
# down lists MODEL_ID (together with OPENAI_API_KEY) as a prerequisite.
MODEL_ID = "o3-2025-04-16"

The last thing we need to configure is the model ID. The official name of the latest version currently available (APS, April 17) is the one assigned to `MODEL_ID` in the adjacent cell. All model names can be checked in the model card on Kaggle.

In [7]:
import json
import pandas as pd
def load_public_data(data_path: str = '///mnt/c/Personal/Competitions/ICML_Track2/input/starting_kit_latest/total.json'):
    """Load the problem set from a JSON file and return it as a list of dicts.

    The parsed JSON is passed through a ``pandas.DataFrame`` and converted back
    with ``to_dict('records')``. For a JSON list of objects this gives every
    returned record the same set of keys (pandas fills absent values with NaN).

    Args:
        data_path: Path of the JSON file to read. Defaults to the full
            ``total.json`` starting-kit file; pass another path (for example a
            smaller ``mini.json``) to load a different problem set.

    Returns:
        A list with one ``dict`` per problem.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit UTF-8 so non-ASCII problem text is decoded the same way on every
    # platform instead of depending on the locale's default encoding.
    with open(data_path, 'r', encoding='utf-8') as file:
        data = json.load(file)
    return pd.DataFrame(data).to_dict('records')
problems = load_public_data()
# problems

In [8]:
import os
import json
import time
import base64
import re
import logging
from typing import Dict, List, Any, Optional, TypedDict, Annotated, Union, Set, Tuple
from enum import Enum
from dataclasses import dataclass
from tqdm.notebook import tqdm

from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from langchain_core.runnables import RunnableConfig
import operator

In [9]:
# Setup logging
def setup_logging(log_file: str = "total_physics_solver.log", verbose: bool = True):
    """Configure root logging to write to ``log_file`` and return a module logger.

    Args:
        log_file: Path of the log file. Its parent directory is created when
            missing; a bare file name is placed in the current directory.
        verbose: When True, records are also sent to a ``StreamHandler``;
            otherwise a ``NullHandler`` takes that slot.

    Returns:
        The logger obtained from ``logging.getLogger(__name__)``.
    """
    target_dir = os.path.dirname(log_file)
    if not target_dir:
        target_dir = '.'
    os.makedirs(target_dir, exist_ok=True)

    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    if verbose:
        console_handler = logging.StreamHandler()
    else:
        console_handler = logging.NullHandler()

    # force=True replaces any handlers already attached to the root logger,
    # so re-running the cell does not stack duplicate handlers.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
        force=True,
    )
    return logging.getLogger(__name__)

# Physics domain definitions.
# PHYSICS_SUBJECTS maps each subject code (the dict key) to its human-readable
# subject name. RejectionImprovementGenerator._get_subject_context uses it to
# build the "<code> (<name>)" subject context string.
PHYSICS_SUBJECTS = {
    'O': 'Optics (Basic)',
    'OPT': 'Optics (Extended/Advanced)', 
    'EM': 'Electromagnetism',
    'CM': 'Classical Mechanics',
    'TSM': 'Thermodynamics & Statistical Mechanics',
    'QMIT': 'Quantum Mechanics & Information Theory',
    'ACG': 'Astrophysics, Cosmology & Gravitation',
    'AMONP': 'Atomic, Molecular, Optical & Nuclear Physics'
}

# Groups the category identifiers by broad physics area. Every identifier listed
# here is also a key of CATEGORY_EXPERTISE.
# NOTE(review): no reader of this table is visible in this part of the notebook.
PHYSICS_CATEGORIES = {
    'MECHANICS': [
        'static_force_analysis', 'spring_force', 'circular_motion', 
        'linear_motion', 'coordinate_system', 'simple_harmonic_motion', 
        'projectile_motion'
    ],
    'ELECTROMAGNETISM': [
        'circuit_diagram', 'charge_distribution', 'magnetic_circuit', 
        'electromagnetic_field', 'capacitance_resistance'
    ],
    'WAVES_OPTICS': [
        'optical_path', 'wave_motion', 'photoelectric_effect', 'acoustics'
    ],
    'THERMAL': ['thermodynamics'],
    'MODERN': [
        'atomic_physics', 'quantum_mechanics', 'relativity_gravity', 
        'feynman_diagram', 'astrophysics'
    ]
}

# Expertise keywords per category, keyed first by category identifier and then by
# language code ('en' or 'zh'). Looked up by
# RejectionImprovementGenerator._get_category_expertise with a problem's
# `img_category` value.
CATEGORY_EXPERTISE = {
    'static_force_analysis': {
        'en': 'Force equilibrium, vector analysis, torque calculations, structural mechanics, friction, constraint forces',
        'zh': '力平衡，矢量分析，力矩计算，结构力学，摩擦力，约束力'
    },
    'spring_force': {
        'en': 'Hooke\'s law, elastic energy, harmonic oscillations, coupled systems, resonance',
        'zh': '胡克定律，弹性能，谐振荡，耦合系统，共振'
    },
    'circular_motion': {
        'en': 'Centripetal force, angular momentum, rotational dynamics, orbital mechanics',
        'zh': '向心力，角动量，转动动力学，轨道力学'
    },
    'linear_motion': {
        'en': 'Kinematics, dynamics, momentum conservation, collision analysis',
        'zh': '运动学，动力学，动量守恒，碰撞分析'
    },
    'coordinate_system': {
        'en': 'Graph analysis, data interpretation, coordinate transformations, vector fields',
        'zh': '图形分析，数据解释，坐标变换，矢量场'
    },
    'simple_harmonic_motion': {
        'en': 'Oscillation equations, period analysis, energy methods, damping effects',
        'zh': '振动方程，周期分析，能量方法，阻尼效应'
    },
    'projectile_motion': {
        'en': 'Parabolic trajectories, range calculations, angle optimization, air resistance',
        'zh': '抛物轨迹，射程计算，角度优化，空气阻力'
    },
    'circuit_diagram': {
        'en': 'Ohm\'s law, Kirchhoff\'s laws, impedance analysis, AC/DC circuits, network analysis',
        'zh': '欧姆定律，基尔霍夫定律，阻抗分析，交直流电路，网络分析'
    },
    'charge_distribution': {
        'en': 'Electric fields, Gauss\'s law, potential calculations, boundary conditions',
        'zh': '电场，高斯定律，电势计算，边界条件'
    },
    'magnetic_circuit': {
        'en': 'Magnetic flux, inductance, transformer principles, magnetic coupling',
        'zh': '磁通量，电感，变压器原理，磁耦合'
    },
    'electromagnetic_field': {
        'en': 'Maxwell equations, wave propagation, Lorentz force, field interactions',
        'zh': '麦克斯韦方程，波传播，洛伦兹力，场相互作用'
    },
    'capacitance_resistance': {
        'en': 'Dielectric properties, resistance networks, Hall effect, field distributions',
        'zh': '介电性质，电阻网络，霍尔效应，场分布'
    },
    'optical_path': {
        'en': 'Ray tracing, lens systems, interference, diffraction, polarization',
        'zh': '光线追迹，透镜系统，干涉，衍射，偏振'
    },
    'wave_motion': {
        'en': 'Wave equations, superposition, standing waves, Doppler effect',
        'zh': '波动方程，叠加原理，驻波，多普勒效应'
    },
    'photoelectric_effect': {
        'en': 'Einstein equation, photon energy, quantum nature of light, work function',
        'zh': '爱因斯坦方程，光子能量，光的量子性，功函数'
    },
    'acoustics': {
        'en': 'Sound propagation, acoustic resonance, sound intensity, echo analysis',
        'zh': '声传播，声共振，声强，回声分析'
    },
    'thermodynamics': {
        'en': 'Laws of thermodynamics, heat transfer, phase transitions, statistical mechanics',
        'zh': '热力学定律，传热，相变，统计力学'
    },
    'atomic_physics': {
        'en': 'Nuclear structure, radioactive decay, particle interactions, cross-sections',
        'zh': '核结构，放射性衰变，粒子相互作用，截面'
    },
    'quantum_mechanics': {
        'en': 'Schrödinger equation, wave functions, quantum states, uncertainty principle',
        'zh': '薛定谔方程，波函数，量子态，不确定性原理'
    },
    'relativity_gravity': {
        'en': 'Special/general relativity, spacetime, gravitational effects, reference frames',
        'zh': '狭义/广义相对论，时空，引力效应，参考系'
    },
    'feynman_diagram': {
        'en': 'Particle interactions, conservation laws, quantum field theory, decay processes',
        'zh': '粒子相互作用，守恒定律，量子场论，衰变过程'
    },
    'astrophysics': {
        'en': 'Stellar physics, cosmology, orbital mechanics, astronomical observations',
        'zh': '恒星物理，宇宙学，轨道力学，天文观测'
    }
}

# Solving guidance per subject, keyed first by subject code (same codes as
# PHYSICS_SUBJECTS) and then by language code ('en' or 'zh'). Read by
# RejectionImprovementGenerator._get_subject_context.
SUBJECT_GUIDANCE = {
    'O': {
        'en': 'Focus on basic geometric optics, ray tracing, and fundamental optical phenomena',
        'zh': '专注于基础几何光学、光线追踪和基本光学现象'
    },
    'OPT': {
        'en': 'Apply advanced optics: wave optics, interference, diffraction, quantum optics',
        'zh': '应用高级光学：波动光学、干涉、衍射、量子光学'
    },
    'EM': {
        'en': 'Emphasize electromagnetic fields, circuits, Maxwell equations, wave propagation',
        'zh': '强调电磁场、电路、麦克斯韦方程、波传播'
    },
    'CM': {
        'en': 'Focus on classical mechanics: forces, motion, energy, momentum conservation',
        'zh': '专注于经典力学：力、运动、能量、动量守恒'
    },
    'TSM': {
        'en': 'Apply thermodynamic laws, statistical mechanics, heat transfer principles',
        'zh': '应用热力学定律、统计力学、传热原理'
    },
    'QMIT': {
        'en': 'Use quantum mechanics, wave-particle duality, quantum information theory',
        'zh': '使用量子力学、波粒二象性、量子信息理论'
    },
    'ACG': {
        'en': 'Apply gravitational physics, cosmology, relativity, astronomical principles',
        'zh': '应用引力物理、宇宙学、相对论、天文学原理'
    },
    'AMONP': {
        'en': 'Focus on atomic structure, nuclear physics, particle interactions',
        'zh': '专注于原子结构、核物理、粒子相互作用'
    }
}


In [10]:
# Complete Fixed Rejection Improvement Pipeline
# Prerequisites: MODEL_ID and OPENAI_API_KEY should be defined in your notebook

import json
import os
import asyncio
import base64
import time
from typing import Dict, List, Any, Optional
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from tqdm.notebook import tqdm
import logging
import nest_asyncio
from enum import Enum
from pydantic import BaseModel, Field

# Apply nest_asyncio for Jupyter compatibility
nest_asyncio.apply()

# Physics Domain Definitions
# Maps each subject code to its human-readable subject name.
# NOTE(review): PHYSICS_SUBJECTS is also assigned in an earlier cell with what
# appears to be the same content; this assignment rebinds the name. Consider
# keeping a single definition.
PHYSICS_SUBJECTS = {
    'O': 'Optics (Basic)',
    'OPT': 'Optics (Extended/Advanced)', 
    'EM': 'Electromagnetism',
    'CM': 'Classical Mechanics',
    'TSM': 'Thermodynamics & Statistical Mechanics',
    'QMIT': 'Quantum Mechanics & Information Theory',
    'ACG': 'Astrophysics, Cosmology & Gravitation',
    'AMONP': 'Atomic, Molecular, Optical & Nuclear Physics'
}

# Solving guidance per subject: subject code -> language code ('en' / 'zh') -> text.
# NOTE(review): SUBJECT_GUIDANCE is also assigned in an earlier cell with what
# appears to be the same content; consider keeping a single definition.
SUBJECT_GUIDANCE = {
    'O': {
        'en': 'Focus on basic geometric optics, ray tracing, and fundamental optical phenomena',
        'zh': '专注于基础几何光学、光线追踪和基本光学现象'
    },
    'OPT': {
        'en': 'Apply advanced optics: wave optics, interference, diffraction, quantum optics',
        'zh': '应用高级光学：波动光学、干涉、衍射、量子光学'
    },
    'EM': {
        'en': 'Emphasize electromagnetic fields, circuits, Maxwell equations, wave propagation',
        'zh': '强调电磁场、电路、麦克斯韦方程、波传播'
    },
    'CM': {
        'en': 'Focus on classical mechanics: forces, motion, energy, momentum conservation',
        'zh': '专注于经典力学：力、运动、能量、动量守恒'
    },
    'TSM': {
        'en': 'Apply thermodynamic laws, statistical mechanics, heat transfer principles',
        'zh': '应用热力学定律、统计力学、传热原理'
    },
    'QMIT': {
        'en': 'Use quantum mechanics, wave-particle duality, quantum information theory',
        'zh': '使用量子力学、波粒二象性、量子信息理论'
    },
    'ACG': {
        'en': 'Apply gravitational physics, cosmology, relativity, astronomical principles',
        'zh': '应用引力物理、宇宙学、相对论、天文学原理'
    },
    'AMONP': {
        'en': 'Focus on atomic structure, nuclear physics, particle interactions',
        'zh': '专注于原子结构、核物理、粒子相互作用'
    }
}

# Expertise keywords per category: category identifier -> language code
# ('en' / 'zh') -> text.
# NOTE(review): CATEGORY_EXPERTISE is also assigned in an earlier cell with what
# appears to be the same content; consider keeping a single definition.
CATEGORY_EXPERTISE = {
    'static_force_analysis': {
        'en': 'Force equilibrium, vector analysis, torque calculations, structural mechanics, friction, constraint forces',
        'zh': '力平衡，矢量分析，力矩计算，结构力学，摩擦力，约束力'
    },
    'spring_force': {
        'en': 'Hooke\'s law, elastic energy, harmonic oscillations, coupled systems, resonance',
        'zh': '胡克定律，弹性能，谐振荡，耦合系统，共振'
    },
    'circular_motion': {
        'en': 'Centripetal force, angular momentum, rotational dynamics, orbital mechanics',
        'zh': '向心力，角动量，转动动力学，轨道力学'
    },
    'linear_motion': {
        'en': 'Kinematics, dynamics, momentum conservation, collision analysis',
        'zh': '运动学，动力学，动量守恒，碰撞分析'
    },
    'coordinate_system': {
        'en': 'Graph analysis, data interpretation, coordinate transformations, vector fields',
        'zh': '图形分析，数据解释，坐标变换，矢量场'
    },
    'simple_harmonic_motion': {
        'en': 'Oscillation equations, period analysis, energy methods, damping effects',
        'zh': '振动方程，周期分析，能量方法，阻尼效应'
    },
    'projectile_motion': {
        'en': 'Parabolic trajectories, range calculations, angle optimization, air resistance',
        'zh': '抛物轨迹，射程计算，角度优化，空气阻力'
    },
    'circuit_diagram': {
        'en': 'Ohm\'s law, Kirchhoff\'s laws, impedance analysis, AC/DC circuits, network analysis',
        'zh': '欧姆定律，基尔霍夫定律，阻抗分析，交直流电路，网络分析'
    },
    'charge_distribution': {
        'en': 'Electric fields, Gauss\'s law, potential calculations, boundary conditions',
        'zh': '电场，高斯定律，电势计算，边界条件'
    },
    'magnetic_circuit': {
        'en': 'Magnetic flux, inductance, transformer principles, magnetic coupling',
        'zh': '磁通量，电感，变压器原理，磁耦合'
    },
    'electromagnetic_field': {
        'en': 'Maxwell equations, wave propagation, Lorentz force, field interactions',
        'zh': '麦克斯韦方程，波传播，洛伦兹力，场相互作用'
    },
    'capacitance_resistance': {
        'en': 'Dielectric properties, resistance networks, Hall effect, field distributions',
        'zh': '介电性质，电阻网络，霍尔效应，场分布'
    },
    'optical_path': {
        'en': 'Ray tracing, lens systems, interference, diffraction, polarization',
        'zh': '光线追迹，透镜系统，干涉，衍射，偏振'
    },
    'wave_motion': {
        'en': 'Wave equations, superposition, standing waves, Doppler effect',
        'zh': '波动方程，叠加原理，驻波，多普勒效应'
    },
    'photoelectric_effect': {
        'en': 'Einstein equation, photon energy, quantum nature of light, work function',
        'zh': '爱因斯坦方程，光子能量，光的量子性，功函数'
    },
    'acoustics': {
        'en': 'Sound propagation, acoustic resonance, sound intensity, echo analysis',
        'zh': '声传播，声共振，声强，回声分析'
    },
    'thermodynamics': {
        'en': 'Laws of thermodynamics, heat transfer, phase transitions, statistical mechanics',
        'zh': '热力学定律，传热，相变，统计力学'
    },
    'atomic_physics': {
        'en': 'Nuclear structure, radioactive decay, particle interactions, cross-sections',
        'zh': '核结构，放射性衰变，粒子相互作用，截面'
    },
    'quantum_mechanics': {
        'en': 'Schrödinger equation, wave functions, quantum states, uncertainty principle',
        'zh': '薛定谔方程，波函数，量子态，不确定性原理'
    },
    'relativity_gravity': {
        'en': 'Special/general relativity, spacetime, gravitational effects, reference frames',
        'zh': '狭义/广义相对论，时空，引力效应，参考系'
    },
    'feynman_diagram': {
        'en': 'Particle interactions, conservation laws, quantum field theory, decay processes',
        'zh': '粒子相互作用，守恒定律，量子场论，衰变过程'
    },
    'astrophysics': {
        'en': 'Stellar physics, cosmology, orbital mechanics, astronomical observations',
        'zh': '恒星物理，宇宙学，轨道力学，天文观测'
    }
}

# Evaluation Models
# Verdict of an evaluation. Subclasses `str` as well as `Enum`, so each member
# compares equal to its lowercase string value ("accept" / "reject").
class EvaluationDecision(str, Enum):
    ACCEPT = "accept"
    REJECT = "reject"

# Errors found in a solution, bucketed by kind. Every field is a list of
# free-text error descriptions and defaults to an empty list, so an instance with
# no arguments means "no errors recorded".
class PhysicsErrorAnalysis(BaseModel):
    physics_theorem_errors: List[str] = Field(default_factory=list, description="Wrong physics laws/formulas applied")
    condition_analysis_errors: List[str] = Field(default_factory=list, description="Misidentified forces, boundaries, setup")
    process_understanding_errors: List[str] = Field(default_factory=list, description="Misunderstood physical phenomena")
    calculation_errors: List[str] = Field(default_factory=list, description="Mathematical derivation errors")
    variable_relationship_errors: List[str] = Field(default_factory=list, description="Wrong dependencies between quantities")
    diagram_analysis_errors: List[str] = Field(default_factory=list, description="Misread visual information")
    boundary_conditions_errors: List[str] = Field(default_factory=list, description="Ignored constraints/limits")

# Full result of evaluating one solution attempt.
class EvaluationResult(BaseModel):
    # Accept/reject verdict.
    decision: EvaluationDecision
    # Both scores are validated to lie in the closed interval [0, 1].
    confidence_score: float = Field(ge=0, le=1)
    quality_score: float = Field(ge=0, le=1)
    # Categorised error lists (see PhysicsErrorAnalysis).
    physics_errors: PhysicsErrorAnalysis
    # Boolean checks; rendered as Yes/No in the regeneration feedback text.
    answer_consistency: bool
    magnitude_reasonable: bool
    # Optional free-text fields default to None; feedback_message is required.
    error_location: Optional[str] = None
    feedback_message: str
    improvement_suggestions: Optional[str] = None

class LabeledExampleManager:
    """Manages labeled examples for few-shot enhancement.

    Samples are dicts read from a JSON file. They are indexed by subject, by
    image category and by the subject/category pair, so the best matching
    examples for a target problem can be selected and rendered into prompt text.
    """

    def __init__(self, labeled_samples_path: str, logger: Optional[logging.Logger] = None):
        """Create the manager and, when the file exists, load and index it.

        Args:
            labeled_samples_path: Path of the JSON file with labeled samples.
                An empty or non-existent path leaves the manager empty.
            logger: Logger to use; defaults to the module logger.
        """
        self.samples = []
        self.subject_category_index = {}
        self.subject_index = {}
        self.category_index = {}
        self.logger = logger or logging.getLogger(__name__)

        if labeled_samples_path and os.path.exists(labeled_samples_path):
            self._load_samples(labeled_samples_path)
            self._build_indices()

    def _load_samples(self, file_path: str):
        """Load labeled samples from a JSON file into ``self.samples``.

        Loading is best-effort: any failure is logged and leaves ``self.samples``
        empty. The file must contain a JSON list; entries that are not JSON
        objects are skipped because every later step calls ``.get`` on them.
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            self.logger.error(f"⚠️ Failed to load labeled samples: {e}")
            self.samples = []
            return

        if not isinstance(loaded, list):
            self.logger.error(
                f"⚠️ Failed to load labeled samples: expected a JSON list, got {type(loaded).__name__}"
            )
            self.samples = []
            return

        self.samples = [entry for entry in loaded if isinstance(entry, dict)]
        skipped = len(loaded) - len(self.samples)
        if skipped:
            self.logger.warning(f"⚠️ Skipped {skipped} labeled samples that are not JSON objects")
        self.logger.info(f"✅ Loaded {len(self.samples)} labeled examples")

    def _build_indices(self):
        """Rebuild the three lookup indices from ``self.samples``.

        The indices are reset first, so calling this method again does not
        duplicate entries.
        """
        self.subject_category_index = {}
        self.subject_index = {}
        self.category_index = {}

        for sample in self.samples:
            subject = sample.get('subject', '')
            category = sample.get('img_category', '')

            # Subject + category, subject-only and category-only indices.
            self.subject_category_index.setdefault(f"{subject}_{category}", []).append(sample)
            self.subject_index.setdefault(subject, []).append(sample)
            self.category_index.setdefault(category, []).append(sample)

    @staticmethod
    def _coerce_level(value: Any, default: float = 1) -> float:
        """Return ``value`` as a number, or ``default`` when it is null or non-numeric."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    def _calculate_quality_score(self, sample: Dict[str, Any], target_problem: Dict[str, Any]) -> float:
        """Score how useful ``sample`` is as an example for ``target_problem``.

        The score is the sum of four parts: reasoning length (up to 0.3),
        closeness of ``level`` (up to 0.3), ``vision_relevance`` match
        (up to 0.2) and ``language`` match (0.1).

        Returns:
            The accumulated score as a float.
        """
        score = 0.0

        # Reasoning length/completeness; a null reasoning counts as empty.
        reasoning = sample.get('reasoning') or ''
        if len(reasoning) > 100:
            score += 0.3
        elif len(reasoning) > 50:
            score += 0.2
        elif len(reasoning) > 20:
            score += 0.1

        # Level preference; null or non-numeric levels fall back to level 1,
        # the same value used when the key is absent.
        sample_level = self._coerce_level(sample.get('level', 1))
        target_level = self._coerce_level(target_problem.get('level', 1))
        level_diff = abs(sample_level - target_level)
        if level_diff == 0:
            score += 0.3
        elif level_diff <= 1:
            score += 0.2
        elif level_diff <= 2:
            score += 0.1

        # Vision relevance match.
        sample_vision = sample.get('vision_relevance', '')
        target_vision = target_problem.get('vision_relevance', '')
        if sample_vision == target_vision:
            score += 0.2
        elif sample_vision in ['necessary', 'helpful'] and target_vision in ['necessary', 'helpful']:
            score += 0.1

        # Language match.
        if sample.get('language', '') == target_problem.get('language', ''):
            score += 0.1

        return score

    def find_best_examples(self, problem: Dict[str, Any], max_examples: int = 3) -> List[Dict[str, Any]]:
        """Find the best examples for ``problem`` by subject and category matching.

        Candidates come from the first tier that has a match: exact
        subject + category, then subject only, then category only. They are
        ranked by ``_calculate_quality_score`` (ties keep their file order).

        Args:
            problem: Target problem dict (reads ``subject`` and ``img_category``).
            max_examples: Maximum number of examples to return; values of zero
                or below return an empty list.

        Returns:
            Up to ``max_examples`` sample dicts, best first.
        """
        if not self.samples or max_examples <= 0:
            return []

        subject = problem.get('subject', '')
        category = problem.get('img_category', '')

        # Tiers in priority order: (label, index to search, key to look up).
        tiers = (
            ("exact", self.subject_category_index, f"{subject}_{category}"),
            ("subject", self.subject_index, subject),
            ("category", self.category_index, category),
        )
        for match_type, index, key in tiers:
            if key in index:
                candidates = index[key]
                break
        else:
            # No matches - skip examples
            self.logger.debug(f"🔍 No examples found for subject='{subject}', category='{category}'")
            return []

        # sorted() is stable, so equally scored candidates keep their order.
        ranked = sorted(
            candidates,
            key=lambda candidate: self._calculate_quality_score(candidate, problem),
            reverse=True,
        )
        selected = ranked[:max_examples]

        self.logger.debug(f"🎯 Found {len(selected)} examples ({match_type} match) for {subject}/{category}")

        return selected

    def format_examples_for_prompt(self, examples: List[Dict[str, Any]], language: str = "English") -> str:
        """Format examples for inclusion in a generation prompt.

        Args:
            examples: Sample dicts (reads ``question``, ``reasoning``, ``answer``,
                ``subject`` and ``img_category``).
            language: ``"Chinese"`` selects the Chinese wording; any other value
                selects English.

        Returns:
            The prompt fragment, or an empty string when ``examples`` is empty.
            Questions longer than 200 characters are truncated with ``"..."``.
        """
        if not examples:
            return ""

        lang_code = "zh" if language == "Chinese" else "en"

        if lang_code == "zh":
            examples_text = "\n**专家解题示例：**\n\n"
            examples_text += "以下是类似问题的专家解题过程，请参考其推理方法和解答模式：\n\n"
        else:
            examples_text = "\n**EXPERT SOLUTION EXAMPLES:**\n\n"
            examples_text += "Here are expert solutions to similar problems. Use these as guides for reasoning patterns and solution approaches:\n\n"

        for i, example in enumerate(examples, 1):
            # A null question is treated as empty instead of raising on len().
            question = example.get('question') or ''
            if len(question) > 200:
                question = question[:200] + "..."
            reasoning = example.get('reasoning', '')
            answer = example.get('answer', '')
            subject = example.get('subject', '')
            category = example.get('img_category', '')

            if lang_code == "zh":
                examples_text += f"**示例 {i}** [{subject}/{category}]:\n"
                examples_text += f"问题: {question}\n"
                examples_text += f"专家推理: {reasoning}\n"
                examples_text += f"答案: {answer}\n\n"
            else:
                examples_text += f"**Example {i}** [{subject}/{category}]:\n"
                examples_text += f"Question: {question}\n"
                examples_text += f"Expert Reasoning: {reasoning}\n"
                examples_text += f"Answer: {answer}\n\n"

        if lang_code == "zh":
            examples_text += "现在请解决以下问题，参考上述专家的推理模式：\n"
        else:
            examples_text += "Now solve the following problem, following the expert reasoning patterns above:\n"

        return examples_text

class RejectionImprovementGenerator:
    """Generator specifically for improving rejected solutions with context injection"""
    
    def __init__(self, api_key: str, model: str, images_base_path: str = "",
                 labeled_samples_path: str = "", logger: Optional[logging.Logger] = None):
        """Set up the chat client, the optional example manager and the logger.

        Args:
            api_key: API key handed to ``ChatOpenAI``.
            model: Model name handed to ``ChatOpenAI``.
            images_base_path: Stored as ``self.images_base_path``.
            labeled_samples_path: When non-empty, a ``LabeledExampleManager`` is
                created for this path; otherwise ``self.example_manager`` is None.
            logger: Logger to use; defaults to the module logger.
        """
        self.client = ChatOpenAI(api_key=api_key, model=model)
        self.images_base_path = images_base_path

        if labeled_samples_path:
            self.example_manager = LabeledExampleManager(labeled_samples_path, logger)
        else:
            self.example_manager = None

        self.logger = logger or logging.getLogger(__name__)
    
    def _get_subject_context(self, subject_code: str, language: str = "en") -> tuple:
        """Get subject context and guidance"""
        if subject_code and subject_code in PHYSICS_SUBJECTS:
            subject_name = PHYSICS_SUBJECTS[subject_code]
            subject_context = f"{subject_code} ({subject_name})"
            subject_guidance = SUBJECT_GUIDANCE.get(subject_code, {}).get(language, "Apply general physics principles")
        else:
            if language == "zh":
                subject_context = "未指定学科"
                subject_guidance = "应用一般物理原理"
            else:
                subject_context = "Unspecified subject"
                subject_guidance = "Apply general physics principles"
        
        return subject_context, subject_guidance
    
    def _get_category_expertise(self, category: str, language: str = "en") -> str:
        """Return the expertise text stored for ``category`` in ``CATEGORY_EXPERTISE``.

        ``language == "zh"`` selects the Chinese text; any other value selects
        English. Unknown categories yield ``"General physics principles"``.
        """
        per_language = CATEGORY_EXPERTISE.get(category, {})
        if language == "zh":
            return per_language.get("zh", "General physics principles")
        return per_language.get("en", "General physics principles")
    
    def _extract_previous_solution_parts(self, rejected_solution: Dict[str, Any]) -> tuple:
        """Extract think and answer parts from previous prediction"""
        rejected_prediction = rejected_solution.get("prediction", "")
        failed_reasoning = ""
        failed_answer = ""
        
        try:
            # Extract <think> content
            if "<think>" in rejected_prediction and "</think>" in rejected_prediction:
                start = rejected_prediction.find("<think>") + 7
                end = rejected_prediction.find("</think>")
                if end > start:
                    failed_reasoning = rejected_prediction[start:end].strip()
        except Exception as e:
            self.logger.warning(f"Failed to extract reasoning from previous prediction: {e}")
        
        try:
            # Extract <answer> content
            if "<answer>" in rejected_prediction and "</answer>" in rejected_prediction:
                start = rejected_prediction.find("<answer>") + 8
                end = rejected_prediction.find("</answer>")
                if end > start:
                    failed_answer = rejected_prediction[start:end].strip()
        except Exception as e:
            self.logger.warning(f"Failed to extract answer from previous prediction: {e}")
        
        return failed_reasoning, failed_answer
    
    def _format_rejection_feedback(self, rejected_solution: Dict[str, Any], language: str = "English") -> str:
        """Format complete feedback from rejection for improvement guidance"""
        evaluation_feedback = rejected_solution.get("evaluation_feedback", "")
        confidence_score = rejected_solution.get("confidence_score", 0.0)
        quality_score = rejected_solution.get("quality_score", 0.0)
        
        lang_code = "zh" if language == "Chinese" else "en"
        
        if lang_code == "zh":
            feedback = "\n**前次失败分析：**\n\n"
            feedback += "您的前次尝试被评估师拒绝。以下是完整的评估反馈：\n\n"
            feedback += f"**评估师反馈**: {evaluation_feedback}\n\n"
            feedback += f"**评估指标**:\n"
            feedback += f"• 信心分数: {confidence_score:.2f}\n"
            feedback += f"• 质量分数: {quality_score:.2f}\n"
            feedback += f"\n**关键改进指示**: 仔细分析上述评估反馈，在新解答中直接解决所有提到的问题。\n"
        else:
            feedback = "\n**PREVIOUS FAILURE ANALYSIS:**\n\n"
            feedback += "Your previous attempt was rejected by the evaluator. Here is the complete evaluation feedback:\n\n"
            feedback += f"**Evaluator Feedback**: {evaluation_feedback}\n\n"
            feedback += f"**Evaluation Metrics**:\n"
            feedback += f"• Confidence Score: {confidence_score:.2f}\n"
            feedback += f"• Quality Score: {quality_score:.2f}\n"
            feedback += f"\n**CRITICAL IMPROVEMENT INSTRUCTIONS**: Carefully analyze the above evaluation feedback and directly address all issues mentioned in your new solution.\n"
        
        return feedback
    
    def _format_previous_attempts_feedback(self, previous_attempts: List[str], evaluation: EvaluationResult, language: str = "English") -> str:
        """Render the last rejected attempt and its evaluation as a prompt section.

        Args:
            previous_attempts: Earlier solution texts; only the last one is shown,
                truncated to 400 characters followed by ``"..."``.
            evaluation: Evaluation of that attempt (error lists, scores, checks,
                feedback message and optional improvement suggestions).
            language: ``"Chinese"`` selects the Chinese wording; any other value
                selects English.

        Returns:
            The feedback block, or ``""`` when either input is empty/missing.
        """
        if not (previous_attempts and evaluation):
            return ""

        # All wording lives in one table per language; the assembly below is shared.
        if language == "Chinese":
            wording = {
                "intro": "\n**前次尝试分析：**\n\n您的前次尝试被拒绝。以下是问题所在：\n\n",
                "error_labels": (
                    ("physics_theorem_errors", "物理定理错误"),
                    ("condition_analysis_errors", "条件分析错误"),
                    ("process_understanding_errors", "过程理解错误"),
                    ("calculation_errors", "计算错误"),
                    ("variable_relationship_errors", "变量关系错误"),
                    ("diagram_analysis_errors", "图表分析错误"),
                    ("boundary_conditions_errors", "边界条件错误"),
                ),
                "metrics": "评估结果",
                "confidence": "信心分数",
                "quality": "质量分数",
                "consistency": "答案一致性",
                "magnitude": "数量级合理性",
                "yes": "是",
                "no": "否",
                "evaluator": "评估师反馈",
                "suggestions": "改进建议",
                "previous": "前次解答（已拒绝）",
                "closing": "\n**重要指示**: 请在新解答中解决上述问题。避免重复相同错误。\n",
            }
        else:
            wording = {
                "intro": "\n**PREVIOUS ATTEMPT ANALYSIS:**\n\nYour previous attempt was rejected. Here's what went wrong:\n\n",
                "error_labels": (
                    ("physics_theorem_errors", "Physics Theorem Errors"),
                    ("condition_analysis_errors", "Condition Analysis Errors"),
                    ("process_understanding_errors", "Process Understanding Errors"),
                    ("calculation_errors", "Calculation Errors"),
                    ("variable_relationship_errors", "Variable Relationship Errors"),
                    ("diagram_analysis_errors", "Diagram Analysis Errors"),
                    ("boundary_conditions_errors", "Boundary Conditions Errors"),
                ),
                "metrics": "Evaluation Metrics",
                "confidence": "Confidence Score",
                "quality": "Quality Score",
                "consistency": "Answer Consistency",
                "magnitude": "Magnitude Reasonable",
                "yes": "Yes",
                "no": "No",
                "evaluator": "Evaluator Feedback",
                "suggestions": "Suggested Improvements",
                "previous": "Previous Solution (REJECTED)",
                "closing": "\n**CRITICAL INSTRUCTIONS**: Address the above issues in your new solution. Avoid repeating the same mistakes.\n",
            }

        pieces = [wording["intro"]]

        # One line per non-empty error category, in a fixed order.
        error_report = evaluation.physics_errors
        for field_name, label in wording["error_labels"]:
            entries = getattr(error_report, field_name)
            if entries:
                pieces.append(f"❌ **{label}**: {', '.join(entries)}\n")

        # Evaluation metrics.
        pieces.append(f"\n**{wording['metrics']}**:\n")
        pieces.append(f"• {wording['confidence']}: {evaluation.confidence_score:.2f}\n")
        pieces.append(f"• {wording['quality']}: {evaluation.quality_score:.2f}\n")
        consistency_word = wording["yes"] if evaluation.answer_consistency else wording["no"]
        pieces.append(f"• {wording['consistency']}: {consistency_word}\n")
        magnitude_word = wording["yes"] if evaluation.magnitude_reasonable else wording["no"]
        pieces.append(f"• {wording['magnitude']}: {magnitude_word}\n")

        # Evaluator feedback and, when present, improvement suggestions.
        pieces.append(f"\n**{wording['evaluator']}**: {evaluation.feedback_message}\n")
        if evaluation.improvement_suggestions:
            pieces.append(f"\n**{wording['suggestions']}**: {evaluation.improvement_suggestions}\n")

        # Most recent attempt, truncated to keep the prompt short.
        latest = previous_attempts[-1]
        shown = latest[:400] + "..." if len(latest) > 400 else latest
        pieces.append(f"\n**{wording['previous']}**:\n```\n{shown}\n```\n")
        pieces.append(wording["closing"])

        return "".join(pieces)

    def _create_improvement_prompt(self, rejected_solution: Dict[str, Any], 
                                previous_attempts: Optional[List[str]] = None,
                                evaluation_feedback: Optional[EvaluationResult] = None) -> str:
        """Build the text prompt asking the model to re-solve a rejected problem.

        The prompt is assembled from the problem metadata, subject/category
        guidance, optional worked examples, feedback about why the solution was
        rejected, and (on the first attempt only) the original failed reasoning
        and answer. A Chinese template is used when the record's ``language`` is
        ``"Chinese"``; every other value falls back to the English template.

        Args:
            rejected_solution: Problem/result record. The keys read here are
                ``question``, ``subject``, ``level``, ``sig_figs``, ``caption``,
                ``language``, ``img_category``, ``vision_relevance`` and
                ``index``; the record is also passed on to the feedback, example
                and extraction helpers.
            previous_attempts: Solutions already produced in this improvement
                loop. Falsy (``None`` or empty) means this is the first attempt.
            evaluation_feedback: Evaluation of the latest attempt. Only used
                when ``previous_attempts`` is also truthy.

        Returns:
            The fully formatted prompt string.
        """
        
        # Extract ALL relevant fields from the problem
        question = rejected_solution.get("question", "")
        subject = rejected_solution.get("subject", "")
        level = rejected_solution.get("level", 1)
        sig_figs = rejected_solution.get("sig_figs", "")
        caption = rejected_solution.get("caption", "")
        language = rejected_solution.get("language", "English")
        img_category = rejected_solution.get("img_category", "")
        vision_relevance = rejected_solution.get("vision_relevance", "")
        index = rejected_solution.get("index", "Unknown")
        
        # Anything other than the exact string "Chinese" selects the English templates.
        lang_code = "zh" if language == "Chinese" else "en"
        
        # Get contexts
        subject_context, subject_guidance = self._get_subject_context(subject, lang_code)
        category_expertise = self._get_category_expertise(img_category, lang_code)
        
        # Get examples if available
        examples_text = ""
        if self.example_manager:
            examples = self.example_manager.find_best_examples(rejected_solution, max_examples=3)
            if examples:
                self.logger.info(f"✅ Injecting {len(examples)} examples for rejected problem {index}")
                examples_text = self.example_manager.format_examples_for_prompt(examples, language)
            else:
                self.logger.info(f"❌ No examples found for rejected problem {index}")
        
        # Get feedback
        feedback_text = ""
        if previous_attempts and evaluation_feedback:
            # This is a regeneration after evaluation
            feedback_text = self._format_previous_attempts_feedback(previous_attempts, evaluation_feedback, language)
        else:
            # This is the first improvement attempt from original rejection
            feedback_text = self._format_rejection_feedback(rejected_solution, language)
        
        # Extract failed reasoning and answer (only for first attempt)
        failed_reasoning, failed_answer = "", ""
        if not previous_attempts:  # Only show original failure on first attempt
            failed_reasoning, failed_answer = self._extract_previous_solution_parts(rejected_solution)
        
        # Prepare conditional sections for failed reasoning and answer
        # (both sections collapse to "" on regeneration or when nothing was extracted)
        if failed_reasoning and not previous_attempts:
            if lang_code == "zh":
                failed_reasoning_section = f"""**原始被拒绝的解答内容：**
    **原始推理过程（被拒绝）：**
    ```
    {failed_reasoning}
    ```
    """
            else:
                failed_reasoning_section = f"""**ORIGINAL REJECTED SOLUTION CONTENT:**
    **Original Reasoning Process (REJECTED):**
    ```
    {failed_reasoning}
    ```
    """
        else:
            failed_reasoning_section = ""
        
        if failed_answer and not previous_attempts:
            if lang_code == "zh":
                failed_answer_section = f"""**原始答案（被拒绝）：**
    ```
    {failed_answer}
    ```
    """
            else:
                failed_answer_section = f"""**Original Answer (REJECTED):**
    ```
    {failed_answer}
    ```
    """
        else:
            failed_answer_section = ""
        
        # Create comprehensive prompt
        # Note: the templates interpolate the display form `subject_context` in the
        # header but the raw `subject` value in the numbered task requirements.
        if lang_code == "zh":
            prompt = f"""你是一位专业的物理学导师，具有深厚的物理学知识。你正在改进一个之前被拒绝的解答。

    **完整问题信息：**
    • 问题编号：{index}
    • 难度等级：{level}/10
    • 物理学科：{subject_context}
    • 问题类别：{img_category}
    • 语言：{language}
    • 视觉相关性：{vision_relevance}
    {f"• 图像说明：{caption}" if caption else "• 图像说明：无"}
    {f"• 有效数字要求：精确到 {sig_figs} 位有效数字" if sig_figs else "• 有效数字要求：未指定"}

    **专业指导：**
    • 学科重点：{subject_guidance}
    • 类别专长：{category_expertise}
    {feedback_text}
    {examples_text}

    **需要重新解决的问题：**
    {question}

    {failed_reasoning_section}{failed_answer_section}

    **改进任务要求：**
    这是一个被评估师拒绝的解答改进任务。请提供完全重新思考的解答：

    1. **仔细分析失败原因**：研究上述评估反馈和失败的解答
    2. **避免相同错误**：不要重复之前的错误推理或计算
    3. **正确应用物理原理**：确保使用适当的{subject}学科知识
    4. **精确数学推导**：所有计算步骤必须准确
    5. **合理最终答案**：检查答案的量纲、数量级和物理意义
    {f"6. **有效数字要求**：最终答案必须精确到{sig_figs}位有效数字" if sig_figs else ""}

    请用以下格式提供改进的解答：
    <think>
    [提供全新的、详细的物理分析和推导过程，完全避免之前的错误]
    </think>

    <answer>
    [提供最终的正确数值答案，符合有效数字要求]
    </answer>"""
        else:
            prompt = f"""You are an expert physics tutor with deep knowledge across all physics domains. You are improving a previously rejected solution.

    **COMPLETE PROBLEM INFORMATION:**
    • Problem Index: {index}
    • Difficulty Level: {level}/10
    • Physics Subject: {subject_context}
    • Problem Category: {img_category}
    • Language: {language}
    • Vision Relevance: {vision_relevance}
    {f"• Image Caption: {caption}" if caption else "• Image Caption: None provided"}
    {f"• Significant Figures: Express answer to exactly {sig_figs} significant figures" if sig_figs else "• Significant Figures: Not specified"}

    **EXPERT GUIDANCE:**
    • Subject Focus: {subject_guidance}
    • Category Expertise: {category_expertise}
    {feedback_text}
    {examples_text}

    **PROBLEM TO RE-SOLVE:**
    {question}

    {failed_reasoning_section}{failed_answer_section}

    **IMPROVEMENT TASK REQUIREMENTS:**
    This is a rejected solution improvement task. Provide a completely re-thought solution:

    1. **Carefully analyze failure reasons**: Study the above evaluation feedback and failed solution
    2. **Avoid the same mistakes**: Do not repeat previous erroneous reasoning or calculations
    3. **Correctly apply physics principles**: Ensure proper use of {subject} subject knowledge
    4. **Accurate mathematical derivations**: All calculation steps must be precise
    5. **Reasonable final answer**: Check answer's dimensions, magnitude, and physical meaning
    {f"6. **Significant figures requirement**: Final answer must be expressed to exactly {sig_figs} significant figures" if sig_figs else ""}

    **CRITICAL**: This solution was previously rejected by an expert evaluator. The feedback above identifies specific issues. Address them systematically and completely.

    Please respond in the following format:
    <think>
    [Provide completely new, detailed physics analysis and derivation process, fully avoiding previous errors]
    </think>

    <answer>
    [Provide final correct numerical answer, meeting significant figures requirements]
    </answer>"""
        
        return prompt    
    
    def _prepare_improvement_content(self, rejected_solution: Dict[str, Any], 
                                   previous_attempts: Optional[List[str]] = None,
                                   evaluation_feedback: Optional[EvaluationResult] = None) -> List[Dict]:
        """Prepare content for improvement including images"""
        prompt = self._create_improvement_prompt(rejected_solution, previous_attempts, evaluation_feedback)
        content = [{"type": "text", "text": prompt}]
        
        # Add images if present
        vision_relevance = rejected_solution.get('vision_relevance', '')
        if vision_relevance in ['necessary', 'helpful', 'optional']:
            image_paths = rejected_solution.get('image_path', [])
            
            for img_path in image_paths:
                full_path = os.path.join(self.images_base_path, img_path) if self.images_base_path else img_path
                
                if os.path.exists(full_path):
                    try:
                        base64_image = self._encode_image_to_base64(full_path)
                        if base64_image:
                            mime_type = self._get_image_mime_type(full_path)
                            content.append({
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{mime_type};base64,{base64_image}",
                                    "detail": "high"
                                }
                            })
                            self.logger.debug(f"✓ Added image: {img_path}")
                    except Exception as e:
                        self.logger.warning(f"⚠️ Failed to load image {img_path}: {e}")
        
        return content
    
    def _encode_image_to_base64(self, image_path: str) -> Optional[str]:
        """Encode image to base64"""
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            self.logger.error(f"Error encoding image {image_path}: {e}")
            return None
    
    def _get_image_mime_type(self, file_path: str) -> str:
        """Get MIME type for image"""
        ext = os.path.splitext(file_path)[1].lower()
        mime_types = {
            '.png': 'image/png',
            '.jpg': 'image/jpeg', 
            '.jpeg': 'image/jpeg',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }
        return mime_types.get(ext, 'image/png')
    
    async def generate_solution(self, rejected_solution: Dict[str, Any], 
                              previous_attempts: Optional[List[str]] = None,
                              evaluation_feedback: Optional[EvaluationResult] = None) -> str:
        """Ask the chat model for an improved solution (async).

        Builds the prompt/image content for the rejected problem, sends it as a
        single human message and returns the model's reply content. Any failure
        is logged and reported back as a ``"Generation error: ..."`` string
        instead of being raised.
        """
        try:
            message = HumanMessage(
                content=self._prepare_improvement_content(
                    rejected_solution, previous_attempts, evaluation_feedback
                )
            )
            # Non-blocking model call
            reply = await self.client.ainvoke([message])
            return reply.content
        except Exception as exc:
            self.logger.error(f"Generation failed: {exc}")
            return f"Generation error: {str(exc)}"

class AsyncPhysicsEvaluator:
    """Async evaluator that asks a chat model to grade a generated physics solution.

    The model returns a structured ``EvaluationResult``; the final accept/reject
    decision is then recomputed locally from the reported confidence and the
    listed physics errors (see ``_apply_decision_logic``).
    """

    # Minimum confidence needed to accept a solution.
    BASE_CONFIDENCE_THRESHOLD = 0.6
    # Stricter bar for hard or image-dependent problems. The system prompt tells
    # the model the threshold is raised by 10% for these; previously the code
    # re-assigned the same 0.6 here, so the adjustment never took effect.
    ADVANCED_CONFIDENCE_THRESHOLD = 0.7
    # Difficulty level from which a problem counts as "advanced".
    ADVANCED_LEVEL = 7

    def __init__(self, api_key: str, model: str, logger: Optional[logging.Logger] = None):
        """Create the chat client and pick the logger (module logger by default)."""
        self.client = ChatOpenAI(api_key=api_key, model=model)
        self.logger = logger or logging.getLogger(__name__)

    def _get_evaluator_system_prompt(self, language: str = "English") -> str:
        """Return the evaluator system prompt (Chinese for "Chinese", else English).

        NOTE(review): the English prompt's "Smart Acceptance Criteria" (accept
        above 50%) and its "Decision Logic" (reject below 60%) disagree, and the
        Chinese prompt uses 50% throughout. The model's own decision is
        overridden by ``_apply_decision_logic``, so the text is left untouched
        here - confirm which wording is intended.
        """
        
        if language == "Chinese":
            return """你是一位专业的物理学专家和解答评估师。你的任务是评估从零生成的物理解答质量。

**评估目标：** 决定是否接受生成的解答，或需要重新生成。

**智能接受标准：**
- **绝不接受信心 < 30%** 的解答（必须重新生成）
- **谨慎接受信心 30-50%** 的解答（仅当物理原理正确且无重大错误时）
- **优先接受信心 > 50%** 的解答（即使有小的格式问题）
- **高级问题（L7+）和视觉问题**：信心阈值提高10%

**物理错误分类：**
- 物理定理错误：应用错误的定律/公式
- 条件分析错误：误识别力、边界、系统设置
- 过程理解错误：误解物理现象发展过程
- 计算错误：数学推导中的错误
- 变量关系错误：量之间的错误依赖关系
- 图表分析错误：误读视觉信息
- 边界条件错误：忽略约束/限制

**决策逻辑：**
- 接受：信心 ≥ 50% 且无重大物理错误，或信心 30-50% 且解答基本正确
- 拒绝：信心 < 30%，或有重大物理原理错误，或答案明显不合理

请诚实评估你的信心水平，这直接影响接受/拒绝决定。"""
        
        else:
            return """You are an expert physics specialist and solution evaluator. Your task is to assess the quality of from-scratch generated physics solutions.

**EVALUATION GOAL:** Decide whether to accept the generated solution or regenerate.

**Smart Acceptance Criteria:**
- **NEVER ACCEPT confidence < 30%** solutions (must regenerate)
- **CAUTIOUSLY ACCEPT confidence 30-50%** solutions (only if physics principles correct with no major errors)
- **READILY ACCEPT confidence > 50%** solutions (even with minor formatting issues)
- **Advanced problems (L7+) and vision problems**: Raise confidence thresholds by 10%

**Physics Error Classification:**
- Physics theorem errors: Wrong laws/formulas applied
- Condition analysis errors: Misidentified forces, boundaries, system setup
- Process understanding errors: Misunderstood physical phenomena development
- Calculation errors: Mathematical derivation mistakes
- Variable relationship errors: Wrong dependencies between quantities
- Diagram analysis errors: Misread visual information
- Boundary conditions errors: Ignored constraints/limits

**Decision Logic:**
- REJECT: Confidence < 60% OR any major physics errors
- ACCEPT: Confidence ≥ 60% AND zero major errors

Please honestly assess your confidence level - this directly affects accept/reject decision."""
    
    def _format_evaluation_prompt(self, problem: Dict[str, Any], solution: str) -> str:
        """Format the user-side evaluation prompt for one problem/solution pair.

        Reads ``language``, ``question``, ``level``, ``img_category``,
        ``subject``, ``caption`` and ``sig_figs`` from ``problem``. The subject
        line is only added when the subject code is a key of ``PHYSICS_SUBJECTS``;
        the caption line only when a caption is present.
        """
        language = problem.get("language", "English")
        
        # Get problem details
        question = problem.get("question", "N/A")
        level = problem.get("level", "Unknown")
        category = problem.get("img_category", "Unknown")
        subject_code = problem.get("subject", "")
        caption = problem.get("caption", "")
        sig_figs = problem.get("sig_figs", "")
        
        # Format subject info
        subject_info = ""
        if subject_code and subject_code in PHYSICS_SUBJECTS:
            subject_name = PHYSICS_SUBJECTS[subject_code]
            if language == "Chinese":
                subject_info = f"\n物理学科: {subject_code} ({subject_name})"
            else:
                subject_info = f"\nPhysics Subject: {subject_code} ({subject_name})"
        
        # Format caption info
        caption_info = ""
        if caption:
            if language == "Chinese":
                caption_info = f"\n图像说明: {caption}"
            else:
                caption_info = f"\nImage Caption: {caption}"
        
        if language == "Chinese":
            prompt = f"""**物理问题：**
问题: {question}
难度等级: {level}/10
问题类别: {category}{subject_info}
语言: {language}{caption_info}
有效数字要求: {sig_figs if sig_figs else "未指定"}

**生成的解答：**
{solution}

**评估任务：**
评估这个从零生成的物理解答质量。决定是否接受此解答或需要重新生成。

**考虑因素：**
1. 物理原理应用的正确性（特别是{subject_code}学科原理）
2. 数学推导和计算的准确性
3. 推理过程的逻辑性和完整性
4. 最终答案的合理性（数量级、单位、方向）
5. 如果有图像：视觉信息的正确解释
6. 有效数字要求的遵守（如果指定）
7. 整体解答质量和清晰度

请提供详细的评估结果。"""
        else:
            prompt = f"""**PHYSICS PROBLEM:**
Question: {question}
Level: {level}/10
Category: {category}{subject_info}
Language: {language}{caption_info}
Required Significant Figures: {sig_figs if sig_figs else "Not specified"}

**GENERATED SOLUTION:**
{solution}

**EVALUATION TASK:**
Assess the quality of this from-scratch generated physics solution. Decide whether to accept or regenerate.

**Consider:**
1. Correctness of physics principle applications (especially {subject_code} subject principles)
2. Accuracy of mathematical derivations and calculations
3. Logical consistency and completeness of reasoning
4. Reasonableness of final answer (magnitude, units, direction)
5. If images provided: correct interpretation of visual information
6. Compliance with significant figures requirements (if specified)
7. Overall solution quality and clarity

Please provide detailed evaluation results."""
        
        return prompt

    @staticmethod
    def _normalize_image_paths(raw_paths: Any) -> List[str]:
        """Coerce an ``image_path`` field into a list of non-empty path strings.

        A single string becomes a one-element list (iterating it directly would
        walk its characters); lists/tuples keep their string entries; anything
        else (None, NaN, ...) yields an empty list.
        """
        if isinstance(raw_paths, str):
            return [raw_paths] if raw_paths else []
        if isinstance(raw_paths, (list, tuple)):
            return [entry for entry in raw_paths if isinstance(entry, str) and entry]
        return []
    
    def _prepare_evaluation_content(self, problem: Dict[str, Any], solution: str, images_base_path: str = "") -> List[Dict]:
        """Build the multimodal message content: the text prompt plus any images.

        Images are attached only when ``vision_relevance`` is ``necessary``,
        ``helpful`` or ``optional``. Unreadable or missing files are skipped
        with a warning so that a bad path cannot abort the evaluation.
        """
        prompt = self._format_evaluation_prompt(problem, solution)
        content: List[Dict] = [{"type": "text", "text": prompt}]
        
        # Add images if present
        if problem.get('vision_relevance', '') not in ('necessary', 'helpful', 'optional'):
            return content

        for img_path in self._normalize_image_paths(problem.get('image_path', [])):
            full_path = os.path.join(images_base_path, img_path) if images_base_path else img_path

            if not os.path.exists(full_path):
                self.logger.warning(f"⚠️ Image not found, skipping: {full_path}")
                continue

            try:
                base64_image = self._encode_image_to_base64(full_path)
                if base64_image:
                    mime_type = self._get_image_mime_type(full_path)
                    content.append({
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{base64_image}",
                            "detail": "high"
                        }
                    })
            except Exception as e:
                self.logger.warning(f"⚠️ Failed to load image {img_path}: {e}")
        return content
    
    def _encode_image_to_base64(self, image_path: str) -> Optional[str]:
        """Return the file's bytes as a base64 string, or None (logged) on failure."""
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            # Previously swallowed silently; log it so a missing figure is traceable.
            self.logger.error(f"Error encoding image {image_path}: {e}")
            return None
    
    def _get_image_mime_type(self, file_path: str) -> str:
        """Map a file extension (case-insensitive) to a MIME type; default image/png."""
        ext = os.path.splitext(file_path)[1].lower()
        mime_types = {
            '.png': 'image/png',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.gif': 'image/gif',
            '.webp': 'image/webp'
        }
        return mime_types.get(ext, 'image/png')
    
    async def evaluate_solution(self, problem: Dict[str, Any], solution: str, images_base_path: str = "") -> "EvaluationResult":
        """Evaluate a physics solution with the chat model (async).

        Returns the model's structured result after the local decision logic has
        been applied. If anything raises, a neutral ACCEPT result (confidence and
        quality 0.5) carrying the error text is returned instead.
        """
        try:
            content = self._prepare_evaluation_content(problem, solution, images_base_path)
            system_prompt = self._get_evaluator_system_prompt(problem.get("language", "English"))
            
            # Use ainvoke for async call
            response = await self.client.with_structured_output(EvaluationResult).ainvoke([
                SystemMessage(content=system_prompt),
                HumanMessage(content=content)
            ])
            
            # Apply confidence and error-based decision logic
            return self._apply_decision_logic(response, problem)
            
        except Exception as e:
            self.logger.error(f"Evaluation failed: {e}")
            # Default to accept to avoid infinite loops
            return EvaluationResult(
                decision=EvaluationDecision.ACCEPT,
                confidence_score=0.5,
                quality_score=0.5,
                physics_errors=PhysicsErrorAnalysis(),
                answer_consistency=True,
                magnitude_reasonable=True,
                feedback_message=f"Evaluation failed: {e}"
            )

    def _confidence_threshold(self, problem: Dict[str, Any]) -> float:
        """Return the minimum confidence required to accept ``problem``'s solution.

        Problems at ``ADVANCED_LEVEL`` or above, or whose ``vision_relevance`` is
        ``necessary``/``helpful``, use the stricter threshold. A missing or
        non-numeric ``level`` is treated as level 1 instead of raising (a raise
        here would surface in ``evaluate_solution`` as an automatic ACCEPT).
        """
        try:
            level = float(problem.get("level", 1))
        except (TypeError, ValueError):
            level = 1.0

        vision_dependent = problem.get("vision_relevance", "") in ('necessary', 'helpful')
        if level >= self.ADVANCED_LEVEL or vision_dependent:
            return self.ADVANCED_CONFIDENCE_THRESHOLD
        return self.BASE_CONFIDENCE_THRESHOLD
    
    def _apply_decision_logic(self, result: "EvaluationResult", problem: Dict[str, Any]) -> "EvaluationResult":
        """Recompute the accept/reject decision from confidence and error counts.

        ``result`` is updated in place (``decision`` and a suffix on
        ``feedback_message``) and returned. The solution is rejected when the
        confidence is below the problem's threshold or when any counted physics
        error is present; otherwise it is accepted.
        """
        original_decision = result.decision
        confidence = result.confidence_score
        
        # Count major physics errors
        # NOTE(review): boundary_conditions_errors is not part of this count even
        # though the system prompt lists it as an error category - confirm intent.
        errors = result.physics_errors
        major_errors = (
            len(errors.physics_theorem_errors) +
            len(errors.condition_analysis_errors) +
            len(errors.process_understanding_errors) +
            len(errors.calculation_errors) +
            len(errors.variable_relationship_errors) +
            len(errors.diagram_analysis_errors)
        )
        
        # Adjust thresholds for advanced/vision problems
        confidence_threshold = self._confidence_threshold(problem)
        
        # Collect rejection reasons
        rejection_reasons = []
        
        # Check for rejection conditions
        if confidence < confidence_threshold:
            rejection_reasons.append(f"Low confidence ({confidence:.2f} < {confidence_threshold})")
        
        if major_errors > 0:
            rejection_reasons.append(f"{major_errors} major physics error(s)")
        
        # Make decision based on rejection reasons
        if rejection_reasons:
            final_decision = EvaluationDecision.REJECT
            override_reason = " + ".join(rejection_reasons)
        else:
            final_decision = EvaluationDecision.ACCEPT
            override_reason = f"Good confidence ({confidence:.2f}) + no major errors"
        
        # Update result with final decision
        result.decision = final_decision
        
        # Tag the feedback with whether the model's own decision was kept or overridden.
        tag = "Override" if original_decision != final_decision else "Confirmed"
        result.feedback_message += f" [{tag}: {override_reason}]"
        
        return result

class RejectionImprovementPipeline:
    """Complete generator-evaluator pipeline for improving rejected solutions"""
    
    def __init__(self, api_key: str, generator_model: str, evaluator_model: str, 
                 max_iterations: int = 3, images_base_path: str = "", 
                 labeled_samples_path: str = "", logger: Optional[logging.Logger] = None):
        """Wire up the generator and evaluator used by the improvement loop.

        The same ``api_key`` and (possibly ``None``) ``logger`` are handed to both
        components; the pipeline itself falls back to the module logger.
        """
        # Plain settings first, collaborators afterwards.
        self.max_iterations = max_iterations
        self.images_base_path = images_base_path
        self.logger = logger or logging.getLogger(__name__)

        generator_options = {
            "api_key": api_key,
            "model": generator_model,
            "images_base_path": images_base_path,
            "labeled_samples_path": labeled_samples_path,
            "logger": logger,
        }
        evaluator_options = {
            "api_key": api_key,
            "model": evaluator_model,
            "logger": logger,
        }
        self.generator = RejectionImprovementGenerator(**generator_options)
        self.evaluator = AsyncPhysicsEvaluator(**evaluator_options)
    
    async def _generator_step(self, rejected_solution: Dict[str, Any], 
                            generation_history: List[str],
                            evaluation_result: Optional[EvaluationResult] = None) -> str:
        """Run one generation attempt and return the produced solution text (async).

        With an empty ``generation_history`` this is the initial improvement
        attempt and no feedback is forwarded. Otherwise the history and the last
        ``evaluation_result`` are passed to the generator as feedback. Any
        exception is logged and returned as a ``"Generation error: ..."`` string.
        """
        try:
            is_regeneration = len(generation_history) > 0

            if not is_regeneration:
                self.logger.info(f"🆕 Initial improvement attempt for rejected problem {rejected_solution.get('index', 'Unknown')}")
                previous_attempts, evaluation_feedback = None, None
            else:
                attempt_num = len(generation_history) + 1
                self.logger.info(f"🔄 Regeneration attempt {attempt_num} with feedback for rejected problem {rejected_solution.get('index', 'Unknown')}")
                previous_attempts, evaluation_feedback = generation_history, evaluation_result

            # Feedback arguments are None on the very first attempt
            solution = await self.generator.generate_solution(
                rejected_solution,
                previous_attempts=previous_attempts,
                evaluation_feedback=evaluation_feedback,
            )

            if not (is_regeneration and evaluation_feedback):
                self.logger.info(f"✅ Generated improved solution for rejected problem {rejected_solution.get('index', 'Unknown')}")
                return solution

            # Regeneration: report how many previously flagged errors were fed back
            found = evaluation_feedback.physics_errors
            error_count = sum(
                len(bucket)
                for bucket in (
                    found.physics_theorem_errors,
                    found.condition_analysis_errors,
                    found.process_understanding_errors,
                    found.calculation_errors,
                    found.variable_relationship_errors,
                    found.diagram_analysis_errors,
                )
            )
            self.logger.info(f"✅ Regenerated solution addressing {error_count} previous errors")
            self.logger.info(f"   Previous confidence: {evaluation_feedback.confidence_score:.2f}")
            return solution

        except Exception as exc:
            self.logger.error(f"❌ Generation failed: {exc}")
            return f"Generation error: {str(exc)}"
    
    async def _evaluator_step(self, rejected_solution: Dict[str, Any], solution: str) -> EvaluationResult:
        """Evaluate one generated solution and log a one-line summary (async).

        An empty solution is rejected immediately without calling the evaluator.
        If the evaluation raises, a neutral ACCEPT result (scores 0.5) carrying
        the error text is returned.
        """
        if solution:
            try:
                outcome = await self.evaluator.evaluate_solution(
                    rejected_solution, solution, self.images_base_path
                )

                # Total of the error lists, used only for the log line below
                found = outcome.physics_errors
                major_errors = sum(
                    len(bucket)
                    for bucket in (
                        found.physics_theorem_errors,
                        found.condition_analysis_errors,
                        found.process_understanding_errors,
                        found.calculation_errors,
                        found.variable_relationship_errors,
                        found.diagram_analysis_errors,
                    )
                )

                summary = (
                    f"📊 Evaluation: {outcome.decision.value.upper()} "
                    f"(Confidence: {outcome.confidence_score:.2f}, "
                    f"Quality: {outcome.quality_score:.2f}, "
                    f"Errors: {major_errors})"
                )
                self.logger.info(summary)
                return outcome

            except Exception as exc:
                self.logger.warning(f"⚠️ Evaluation failed: {exc}")
                return EvaluationResult(
                    decision=EvaluationDecision.ACCEPT,
                    confidence_score=0.5,
                    quality_score=0.5,
                    physics_errors=PhysicsErrorAnalysis(),
                    answer_consistency=True,
                    magnitude_reasonable=True,
                    feedback_message=f"Evaluation failed: {exc}"
                )

        # Nothing to evaluate
        return EvaluationResult(
            decision=EvaluationDecision.REJECT,
            confidence_score=0.0,
            quality_score=0.0,
            physics_errors=PhysicsErrorAnalysis(),
            answer_consistency=False,
            magnitude_reasonable=False,
            feedback_message="No solution provided"
        )
    
    def _routing_logic(self, generation_history: List[str], 
                     evaluation_result: EvaluationResult,
                     max_iterations: int) -> str:
        """Decide what the improvement loop should do after an evaluation.

        Args:
            generation_history: All solutions generated so far in this loop.
            evaluation_result: Evaluation of the latest solution (may be None).
            max_iterations: Maximum number of generation attempts allowed.

        Returns:
            ``"accept"`` when the latest solution was accepted,
            ``"max_iterations"`` when it was not accepted and the attempt budget
            is used up, otherwise ``"reject"`` (generate again).
        """
        # Check acceptance first: a solution accepted on the final allowed attempt
        # is a success, not a "max iterations reached" warning.
        if evaluation_result and evaluation_result.decision == EvaluationDecision.ACCEPT:
            return "accept"

        if len(generation_history) >= max_iterations:
            self.logger.warning(f"⚠️ Max iterations ({max_iterations}) reached")
            return "max_iterations"

        return "reject"
    
    async def improve_single_rejection(self, rejected_solution: Dict[str, Any]) -> Dict[str, Any]:
        """Improve one rejected solution with the generate -> evaluate loop (async).

        The loop regenerates until the routing logic reports ``"accept"`` or
        ``"max_iterations"``. The input record is never modified: a copy is
        returned with the new ``prediction``, the final evaluation fields and
        bookkeeping about the attempts made.

        Args:
            rejected_solution: Result record of a previously rejected solution.

        Returns:
            On success, a copy with ``improvement_status == 'improved'``. If
            anything raises, a copy with ``improvement_status == 'failed'`` and
            the error text under ``improvement_error``.
        """
        
        index = rejected_solution.get('index', 'Unknown')
        
        try:
            # Initialize state
            generation_history: List[str] = []
            current_solution = None
            evaluation_result = None
            iterations_used = 0
            
            # Simple async workflow loop
            while True:
                # Generation step (receives the previous evaluation as feedback)
                current_solution = await self._generator_step(
                    rejected_solution, generation_history, evaluation_result
                )
                generation_history.append(current_solution)
                iterations_used += 1
                
                # Evaluation step
                evaluation_result = await self._evaluator_step(rejected_solution, current_solution)
                
                # Routing logic: anything other than accept / max_iterations regenerates
                decision = self._routing_logic(generation_history, evaluation_result, self.max_iterations)
                if decision in ("accept", "max_iterations"):
                    break
            
            final_solution = current_solution
            
            # Create updated solution
            updated_solution = rejected_solution.copy()
            updated_solution['prediction'] = final_solution
            
            # Add improvement metadata
            updated_solution['improvement_status'] = 'improved'
            # .get(): a record without a stored prediction must not turn a finished
            # improvement into a "failed" one via KeyError.
            updated_solution['original_prediction'] = rejected_solution.get('prediction')
            updated_solution['improvement_timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S')
            updated_solution['improvement_method'] = 'rejection_feedback_enhancement_with_evaluation_loop'
            
            # Update evaluation fields with final results
            updated_solution['final_decision'] = evaluation_result.decision.value
            updated_solution['confidence_score'] = evaluation_result.confidence_score
            updated_solution['quality_score'] = evaluation_result.quality_score
            updated_solution['evaluation_feedback'] = evaluation_result.feedback_message
            
            # Update generation tracking.
            # dict.copy() is shallow, so extending the existing list in place would
            # also change the caller's record; build a new list instead.
            prior_history = rejected_solution.get('generation_history')
            if not isinstance(prior_history, list):
                prior_history = []
            updated_solution['generation_history'] = list(prior_history) + generation_history
            updated_solution['total_generations'] = len(updated_solution['generation_history'])
            updated_solution['iterations_used'] = (rejected_solution.get('iterations_used') or 0) + iterations_used
            
            # Add improvement-specific tracking
            updated_solution['improvement_iterations'] = iterations_used
            updated_solution['improvement_generations'] = len(generation_history)
            
            self.logger.info(f"✅ Completed improvement for problem {index}: "
                           f"{evaluation_result.decision.value} after {iterations_used} iterations")
            
            return updated_solution
            
        except Exception as e:
            self.logger.error(f"❌ Failed to improve solution {index}: {e}")
            # Return original with error flag
            error_solution = rejected_solution.copy()
            error_solution['improvement_status'] = 'failed'
            error_solution['improvement_error'] = str(e)
            error_solution['improvement_timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S')
            return error_solution
    
    async def improve_all_rejections(self, results_file_path: str,
                                   output_file_path: Optional[str] = None,
                                   batch_size: int = 5,
                                   batch_timeout: float = 12000) -> List[Dict[str, Any]]:
        """Improve all rejected solutions in the results file.

        Loads the JSON results, selects every entry whose ``final_decision`` is
        ``'reject'`` and whose ``improvement_status`` is not ``'improved'``, runs
        ``improve_single_rejection`` on them in concurrent batches, merges the
        outcomes back into the full list by ``index`` and writes that list to disk.

        Args:
            results_file_path: JSON file holding a list of result dicts.
            output_file_path: Where to write the merged results. When falsy, the
                input name gets a ``_rejection_improved`` suffix before its
                extension (``.json`` is appended if the input has no extension).
            batch_size: Number of rejected solutions processed concurrently.
            batch_timeout: Seconds allowed for one whole batch before every
                solution in it is recorded with status ``'timeout'``
                (default 12000 s = 200 minutes).

        Returns:
            The full results list with processed entries replaced. When nothing
            needs improving, the loaded list is returned unchanged and no file
            is written.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 while there are
                rejected solutions to process.
        """

        def mark_unimproved(solution: Dict[str, Any], status: str, error: str) -> Dict[str, Any]:
            """Copy ``solution`` and tag the copy with why it could not be improved."""
            flagged = solution.copy()
            flagged['improvement_status'] = status
            flagged['improvement_error'] = error
            flagged['improvement_timestamp'] = time.strftime('%Y-%m-%d %H:%M:%S')
            return flagged

        # Load existing results
        with open(results_file_path, 'r', encoding='utf-8') as f:
            all_results = json.load(f)

        self.logger.info(f"📁 Loaded {len(all_results)} total results from {results_file_path}")

        # Find rejected solutions that haven't been improved yet
        rejected_solutions = [
            result for result in all_results
            if result.get('final_decision') == 'reject'
            and result.get('improvement_status') != 'improved'
        ]

        self.logger.info(f"🎯 Found {len(rejected_solutions)} rejected solutions to improve")

        if not rejected_solutions:
            self.logger.info("✅ No new rejected solutions found to improve!")
            return all_results

        # A negative step would silently yield zero batches; fail loudly instead.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        # Create batches for async processing
        batches = [rejected_solutions[i:i + batch_size] for i in range(0, len(rejected_solutions), batch_size)]

        # Process batches with progress tracking
        improved_solutions = []

        progress_bar = tqdm(batches, desc="Improving Rejections")

        try:
            for batch_num, batch in enumerate(progress_bar, 1):
                self.logger.info(f"🚀 Processing batch {batch_num}/{len(batches)} with {len(batch)} rejections")

                # Process batch asynchronously
                tasks = [self.improve_single_rejection(solution) for solution in batch]

                try:
                    batch_results = await asyncio.wait_for(
                        asyncio.gather(*tasks, return_exceptions=True),
                        timeout=batch_timeout
                    )
                except asyncio.TimeoutError:
                    self.logger.error(f"❌ Batch {batch_num} timed out")
                    # One timeout record per solution, in the same order as the batch
                    batch_results = [
                        mark_unimproved(solution, 'timeout', 'Batch processing timed out')
                        for solution in batch
                    ]

                # gather() keeps input order, so results pair up with their solutions.
                # BaseException (not Exception) is checked because a cancelled task is
                # reported as CancelledError, which must never be stored as a result.
                for solution, result in zip(batch, batch_results):
                    if isinstance(result, BaseException):
                        self.logger.error(f"Batch processing error: {result}")
                        result = mark_unimproved(
                            solution, 'failed', str(result) or type(result).__name__
                        )
                    improved_solutions.append(result)

                # Update progress
                successful_improvements = sum(
                    1 for s in improved_solutions if s.get('improvement_status') == 'improved'
                )
                accepted_improvements = sum(
                    1 for s in improved_solutions if s.get('final_decision') == 'accept'
                )

                progress_bar.set_postfix({
                    'Improved': successful_improvements,
                    'Accepted': accepted_improvements,
                    'Remaining': len(rejected_solutions) - len(improved_solutions)
                })

                # Small delay between batches to prevent rate limiting
                # (pointless after the final batch, so it is skipped there)
                if batch_num < len(batches):
                    await asyncio.sleep(5)
        finally:
            # Release the progress widget even when a batch raises
            progress_bar.close()

        # Update the full results list
        # NOTE(review): entries sharing an 'index' (or lacking one) collide here — confirm indexes are unique
        improved_lookup = {sol.get('index'): sol for sol in improved_solutions}

        updated_results = []
        for result in all_results:
            index = result.get('index')
            if index in improved_lookup:
                updated_results.append(improved_lookup[index])
                self.logger.debug(f"✓ Updated solution for index {index}")
            else:
                updated_results.append(result)

        # Save updated results
        if output_file_path:
            output_path = output_file_path
        else:
            # splitext only touches the real extension, so the derived name can
            # never equal the input path (which would overwrite the source file)
            root, ext = os.path.splitext(results_file_path)
            output_path = f"{root}_rejection_improved{ext or '.json'}"

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(updated_results, f, indent=2, ensure_ascii=False)

        # Analytics
        successful_improvements = sum(1 for s in improved_solutions if s.get('improvement_status') == 'improved')
        failed_improvements = sum(1 for s in improved_solutions if s.get('improvement_status') == 'failed')
        accepted_improvements = sum(1 for s in improved_solutions if s.get('final_decision') == 'accept')
        total_iterations = sum(s.get('improvement_iterations', 0) for s in improved_solutions)
        total_generations = sum(s.get('improvement_generations', 0) for s in improved_solutions)

        self.logger.info(f"💾 Saved {len(updated_results)} updated results to {output_path}")
        self.logger.info(f"✅ Rejection Improvement Summary:")
        self.logger.info(f"   - Total solutions: {len(all_results)}")
        self.logger.info(f"   - Rejected solutions found: {len(rejected_solutions)}")
        self.logger.info(f"   - Successfully improved: {successful_improvements}")
        self.logger.info(f"   - Failed improvements: {failed_improvements}")
        self.logger.info(f"   - Final accepted: {accepted_improvements}")
        self.logger.info(f"   - Success rate: {successful_improvements/len(rejected_solutions)*100:.1f}%")
        self.logger.info(f"   - Acceptance rate: {accepted_improvements/len(rejected_solutions)*100:.1f}%")
        self.logger.info(f"   - Average iterations per problem: {total_iterations/len(rejected_solutions):.1f}")
        self.logger.info(f"   - Average generations per problem: {total_generations/len(rejected_solutions):.1f}")

        return updated_results

def setup_logging(log_file: str = "rejection_improvement.log", verbose: bool = True):
    """Configure root logging for the rejection-improvement run.

    Args:
        log_file: Path of the UTF-8 log file; its parent directory is created
            when missing.
        verbose: When True, records are also sent to a ``StreamHandler``;
            otherwise a ``NullHandler`` takes that slot.

    Returns:
        The logger named after this module (``__name__``).
    """
    # A bare file name has no directory part; fall back to the current directory
    target_dir = os.path.dirname(log_file) or '.'
    os.makedirs(target_dir, exist_ok=True)

    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    if verbose:
        console_handler = logging.StreamHandler()
    else:
        console_handler = logging.NullHandler()

    # force=True replaces any handlers already attached to the root logger,
    # so re-running this in the same kernel does not duplicate log lines
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
        force=True,
    )
    return logging.getLogger(__name__)

def run_rejection_improvement_pipeline(results_file_path: str,
                                     api_key: Optional[str] = None,
                                     images_base_path: str = "",
                                     labeled_samples_path: str = "",
                                     output_file_path: Optional[str] = None,
                                     generator_model: Optional[str] = None,
                                     evaluator_model: Optional[str] = None,
                                     max_iterations: int = 3,
                                     batch_size: int = 5,
                                     log_file: str = "rejection_improvement.log",
                                     verbose: bool = True) -> List[Dict[str, Any]]:
    """
    Main function to run the complete rejection improvement pipeline

    Arguments are validated before anything else happens, so an invalid call
    neither creates the log file nor reconfigures logging.

    Args:
        results_file_path: Path to the JSON file with results
        api_key: OpenAI API key (if None or empty, the OPENAI_API_KEY
            environment variable is used)
        images_base_path: Base path for images
        labeled_samples_path: Path to labeled examples JSON for few-shot learning
        output_file_path: Output file path (if None, adds '_rejection_improved' suffix)
        generator_model: Model to use for generation (must be provided)
        evaluator_model: Model to use for evaluation (must be provided)
        max_iterations: Max regeneration attempts per rejection
        batch_size: Number of solutions to process in parallel
        log_file: Log file path
        verbose: Whether to log to console

    Returns:
        The list returned by ``RejectionImprovementPipeline.improve_all_rejections``.

    Raises:
        ValueError: If no API key can be resolved, or if ``generator_model`` /
            ``evaluator_model`` is missing.
        Exception: Any error raised while the pipeline runs is logged and re-raised.
    """

    # Get API key (an empty string is treated like "not supplied")
    if not api_key:
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")

    # Validate model parameters
    if not generator_model:
        raise ValueError("generator_model must be provided")
    if not evaluator_model:
        raise ValueError("evaluator_model must be provided")

    # Setup logging (only once the arguments are known to be usable)
    logger = setup_logging(log_file, verbose)

    # Initialize complete pipeline
    pipeline = RejectionImprovementPipeline(
        api_key=api_key,
        generator_model=generator_model,
        evaluator_model=evaluator_model,
        max_iterations=max_iterations,
        images_base_path=images_base_path,
        labeled_samples_path=labeled_samples_path,
        logger=logger
    )

    # Log configuration
    logger.info("🚀 Starting Complete Rejection Improvement Pipeline")
    logger.info(f"   Generator Model: {generator_model}")
    logger.info(f"   Evaluator Model: {evaluator_model}")
    logger.info(f"   Max Iterations: {max_iterations}")
    logger.info(f"   Images Base Path: {images_base_path}")
    logger.info(f"   Labeled Samples: {labeled_samples_path}")
    logger.info(f"   Batch Size: {batch_size}")

    # Run improvement with complete generator-evaluator loop
    try:
        # asyncio.run() refuses to start inside an already running event loop
        # (the normal situation in Jupyter); this notebook applies nest_asyncio
        # in an early cell, which is what makes the call work here.
        improved_results = asyncio.run(
            pipeline.improve_all_rejections(
                results_file_path=results_file_path,
                output_file_path=output_file_path,
                batch_size=batch_size
            )
        )

        logger.info("🎉 Complete rejection improvement pipeline completed successfully!")
        logger.info(f"📝 Detailed logs saved to: {log_file}")
        return improved_results

    except Exception as e:
        logger.error(f"❌ Pipeline failed: {e}")
        raise



In [11]:
# Example usage: re-run the rejected solutions with the notebook-level MODEL_ID
if __name__ == "__main__":
    improved_results = run_rejection_improvement_pipeline(
        # Input and output result files
        results_file_path="total_physics_results_o3.json",
        output_file_path="total_physics_results_o3_rejection_improved.json",
        # Data locations
        images_base_path="///mnt/c/Personal/Competitions/ICML_Track2/input/starting_kit_latest/",
        labeled_samples_path="///mnt/c/Personal/Competitions/ICML_Track2/input/starting_kit_latest/dev.json",
        # The same model acts as generator and evaluator
        generator_model=MODEL_ID,
        evaluator_model=MODEL_ID,
        # Run settings
        max_iterations=5,
        batch_size=25,
        verbose=True,
    )

2025-06-15 15:54:15,961 - INFO - ✅ Loaded 200 labeled examples
2025-06-15 15:54:15,992 - INFO - 🚀 Starting Complete Rejection Improvement Pipeline
2025-06-15 15:54:15,993 - INFO -    Generator Model: o3-2025-04-16
2025-06-15 15:54:15,994 - INFO -    Evaluator Model: o3-2025-04-16
2025-06-15 15:54:15,994 - INFO -    Max Iterations: 5
2025-06-15 15:54:15,995 - INFO -    Images Base Path: ///mnt/c/Personal/Competitions/ICML_Track2/input/starting_kit_latest/
2025-06-15 15:54:15,996 - INFO -    Labeled Samples: ///mnt/c/Personal/Competitions/ICML_Track2/input/starting_kit_latest/dev.json
2025-06-15 15:54:15,996 - INFO -    Batch Size: 25
2025-06-15 15:54:16,175 - INFO - 📁 Loaded 2000 total results from total_physics_results_o3.json
2025-06-15 15:54:16,176 - INFO - 🎯 Found 105 rejected solutions to improve


Improving Rejections:   0%|          | 0/5 [00:00<?, ?it/s]

2025-06-15 15:54:16,185 - INFO - 🚀 Processing batch 1/5 with 25 rejections
2025-06-15 15:54:16,186 - INFO - 🆕 Initial improvement attempt for rejected problem 70
2025-06-15 15:54:16,186 - INFO - ✅ Injecting 3 examples for rejected problem 70
2025-06-15 15:54:16,188 - INFO - 🆕 Initial improvement attempt for rejected problem 87
2025-06-15 15:54:16,190 - INFO - ✅ Injecting 3 examples for rejected problem 87
2025-06-15 15:54:16,190 - INFO - 🆕 Initial improvement attempt for rejected problem 88
2025-06-15 15:54:16,191 - INFO - ✅ Injecting 3 examples for rejected problem 88
2025-06-15 15:54:16,192 - INFO - 🆕 Initial improvement attempt for rejected problem 152
2025-06-15 15:54:16,192 - INFO - ✅ Injecting 3 examples for rejected problem 152
2025-06-15 15:54:16,194 - INFO - 🆕 Initial improvement attempt for rejected problem 170
2025-06-15 15:54:16,195 - INFO - ✅ Injecting 1 examples for rejected problem 170
2025-06-15 15:54:16,196 - INFO - 🆕 Initial improvement attempt for rejected problem 18