In [None]:
import os
import sys

os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import warnings
warnings.filterwarnings('ignore')

import pandas as pd
import polars as pl
import re
from typing import Optional
from collections import Counter
import sympy as sp
from sympy import symbols, solve
import kaggle_evaluation.aimo_3_inference_server

In [None]:
def extract_answer(text: str) -> Optional[int]:
    """Pull the final integer answer (0-99999) out of free-form model output.

    Patterns are tried from most to least reliable (``\\boxed{}`` first).
    Within one pattern the matches are scanned from last to first, because
    models usually restate their final answer at the end. The first in-range
    value wins, so an out-of-range last match no longer hides an earlier
    valid match of the same reliable pattern.

    Returns:
        The extracted integer, or None when no pattern yields a value in
        [0, 99999] (also for empty/None text).
    """
    if not text:
        return None
    patterns = [
        # \s* tolerates "\boxed{ 42 }".
        r'\\boxed\{\s*(\d{1,10})\s*\}',
        r'(?:answer|result|solution)(?:\s+is)?:?\s*(\d{1,10})',
        r'(?:final answer|the answer).*?(\d{1,10})',
        r'=\s*(\d{1,10})(?:\s|$|\.|,)',
        r'(\d{1,10})(?:\s+(?:is the|as the) answer)',
        r'therefore.*?(\d{1,10})',
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
        for raw in reversed(matches):
            try:
                answer = int(raw)
            except ValueError:
                continue
            if 0 <= answer <= 99999:
                return answer
    return None

def parse_problem(problem_text: str) -> dict:
    """Extract structural hints from a problem statement.

    Returns a dict with 'text' (the input, unchanged) and 'modulo' (the
    modulus the final answer should be reduced by, or None). Recognised
    phrasings are "remainder when ... divided by N", "modulo N" and
    "mod N" (which also covers "(mod N)" and "pmod{N}"). N may be wrapped
    in LaTeX delimiters ("$1000$", "{1000}") and may be written as a power
    ("10^5", "$10^{5}$"), in which case the power is evaluated.
    """
    result = {'text': problem_text, 'modulo': None}
    # Modulus operand: skip whitespace and LaTeX `$` / `{`, then digits with an
    # optional "^k" or "^{k}" exponent (so "10^5" is not misread as 10).
    operand = r'[\s${]*(\d+)(?:\s*\^\s*\{?\s*(\d+)\s*\}?)?'
    modulo_patterns = [
        # [\s\S]*? lets the phrase span line breaks in multi-line statements.
        r'remainder when[\s\S]*?divided by' + operand,
        r'modulo' + operand,
        r'mod' + operand,
    ]
    for pattern in modulo_patterns:
        match = re.search(pattern, problem_text, re.IGNORECASE)
        if match:
            base, exponent = match.group(1), match.group(2)
            result['modulo'] = int(base) ** int(exponent) if exponent else int(base)
            break
    return result

def format_answer(answer: int, modulo: Optional[int] = None) -> int:
    """Map a raw integer answer into the competition range [0, 99999].

    When a truthy `modulo` is given, the answer is reduced by it first.
    The absolute value is then folded into range with % 100000.
    """
    reduced = answer % modulo if modulo else answer
    return abs(reduced) % 100000

def validate_answer(answer: int) -> bool:
    """True iff `answer` is an int within the competition range 0..99999."""
    if not isinstance(answer, int):
        return False
    return 0 <= answer <= 99999

In [None]:
class MathTools:
    """Small wrappers around SymPy helpers used by the symbolic solving path."""

    def solve_equation(self, equation_str: str, variable: str = 'x') -> list:
        """Solve a single-variable equation given as text.

        Accepts "lhs = rhs", "lhs == rhs", or a bare expression (treated as
        "expr = 0").

        Returns:
            SymPy's list of solutions for `variable`, or [] when the text
            cannot be parsed or solved.

        NOTE(review): sp.sympify evaluates its input with eval(); only pass
        text from trusted sources (here, competition problem statements).
        """
        try:
            target = symbols(variable)
            # Treat "==" like "=", and split on the first "=" only; a stray
            # extra "=" then ends up in `rhs` and fails to parse (-> []).
            lhs, has_eq, rhs = equation_str.replace('==', '=').partition('=')
            if has_eq:
                expr = sp.sympify(lhs) - sp.sympify(rhs)
            else:
                expr = sp.sympify(equation_str)
            return solve(expr, target)
        except Exception:
            # Best effort: any parse/solve failure means "no symbolic solution".
            # (Not a bare except, so KeyboardInterrupt/SystemExit still propagate.)
            return []

    def is_prime(self, n: int) -> bool:
        """Primality test via sp.isprime."""
        return sp.isprime(n)

    def prime_factors(self, n: int) -> dict:
        """Prime factorisation via sp.factorint (prime -> exponent)."""
        return sp.factorint(n)

    def divisors(self, n: int) -> list:
        """Divisors of n via sp.divisors."""
        return sp.divisors(n)

# Shared instance used by the solver class.
math_tools = MathTools()

In [None]:
class MultiStrategyLLMSolver:
    """Ensemble solver: one LLM sampled with several strategies, answers majority-voted.

    A cheap symbolic pass (for "solve ... for ..." statements) is tried first.
    The LLM ensemble is the fallback, and 0 is returned when nothing works.
    """

    def __init__(self, num_strategies=5):
        """Set up strategies and prompt templates; the LLM is loaded lazily.

        Args:
            num_strategies: how many of the predefined sampling strategies to
                run per problem (values above the number defined use them all).
        """
        self.tools = math_tools
        self.num_strategies = num_strategies
        self.llm_tokenizer = None
        self.llm_model = None
        # Set once a load attempt fails so later calls do not repeat the slow
        # failing load for every strategy of every problem.
        self._llm_load_failed = False

        # Sampling settings per strategy: temperature / nucleus top_p / top_k.
        self.strategies = [
            {'name': 'step_by_step', 'temp': 0.6, 'top_p': 0.9, 'top_k': 40},
            {'name': 'creative', 'temp': 0.9, 'top_p': 0.95, 'top_k': 50},
            {'name': 'focused', 'temp': 0.5, 'top_p': 0.85, 'top_k': 30},
            {'name': 'exploratory', 'temp': 1.0, 'top_p': 0.95, 'top_k': 60},
            {'name': 'balanced', 'temp': 0.7, 'top_p': 0.9, 'top_k': 45},
        ]

        # Prompts are cycled across strategies; `{{}}` renders as a literal `{}`.
        self.prompt_templates = [
            """Solve this IMO problem step-by-step.

Problem: {problem}

Approach:
1. Understand what's being asked
2. Identify key concepts (algebra, number theory, combinatorics, geometry)
3. Work through systematically
4. Verify your answer
5. Give final answer as integer 0-99999 in \\boxed{{}}

Solution:""",
            """You are solving an olympiad math problem.

Problem: {problem}

Think carefully and solve step-by-step. Express your final numerical answer (0-99999) clearly using \\boxed{{answer}}.""",
            """Analyze and solve:

{problem}

Break this down:
- What type of problem is this?
- What techniques apply?
- What's the solution path?

Final answer (0-99999) in \\boxed{{}}:""",
        ]

    def _init_llm(self):
        """Load tokenizer and model on first use.

        This is a no-op once the model is loaded or after a failed attempt.
        It prefers the Kaggle dataset copy of the model and falls back to the
        Hugging Face hub id. Load errors are printed, not raised; callers
        detect failure through `self.llm_model is None`.
        """
        if self.llm_model is not None or getattr(self, '_llm_load_failed', False):
            return

        try:
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM

            local_path = '/kaggle/input/qwen2-5-math-1-5b-instruct/qwen_math'
            if os.path.exists(local_path):
                model_path = local_path
            else:
                model_path = "Qwen/Qwen2.5-Math-1.5B-Instruct"
            use_cuda = torch.cuda.is_available()

            print(f"Loading LLM: {model_path}")

            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
            )
            # Inference only: put the model in eval mode on CPU as well as GPU.
            self.llm_model.eval()

            print(f"LLM ready on {'GPU' if use_cuda else 'CPU'}")
        except Exception as e:
            self._llm_load_failed = True
            print(f"LLM load failed: {e}")

    @staticmethod
    def _first_integer_solution(solutions) -> Optional[int]:
        """Return the first solution that is exactly an integer, else None.

        Non-integer roots (e.g. 5/2, sqrt(2)) are skipped rather than
        truncated by int(). Values that cannot be converted (complex or
        symbolic roots, infinities) are skipped too.
        """
        for sol in solutions:
            try:
                as_int = int(sol)
            except (TypeError, ValueError, OverflowError):
                continue
            if as_int == sol:
                return as_int
        return None

    def solve_symbolic(self, problem: str) -> Optional[int]:
        """Try to answer "solve <equation> for ..." problems with SymPy.

        Returns:
            The first exactly-integer root, passed through format_answer with
            the problem's stated modulus (if any). None when the statement
            does not fit the pattern or no integer root exists.
        """
        eq_match = re.search(r'solve\s+(.+?)\s+for', problem, re.IGNORECASE)
        if eq_match is None:
            return None
        try:
            solutions = self.tools.solve_equation(eq_match.group(1).strip())
        except Exception:
            # Symbolic solving is best-effort; the LLM ensemble is the fallback.
            return None
        answer = self._first_integer_solution(solutions)
        if answer is None:
            return None
        return format_answer(answer, parse_problem(problem)['modulo'])

    def solve_with_strategy(self, problem: str, strategy: dict, prompt_idx: int) -> Optional[int]:
        """Sample one LLM solution and extract its integer answer.

        Args:
            problem: problem statement.
            strategy: one entry of `self.strategies` (sampling settings).
            prompt_idx: index into `self.prompt_templates`, taken modulo
                the number of templates.

        Returns:
            The answer from extract_answer, or None when the model is
            unavailable or generation fails (errors are printed).
        """
        self._init_llm()
        if self.llm_model is None:
            return None

        try:
            import torch

            prompt = self.prompt_templates[prompt_idx % len(self.prompt_templates)].format(problem=problem)

            messages = [
                {"role": "system", "content": "You are an expert mathematician specializing in olympiad problems."},
                {"role": "user", "content": prompt}
            ]

            text = self.llm_tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

            inputs = self.llm_tokenizer([text], return_tensors="pt")
            if hasattr(self.llm_model, 'device'):
                inputs = inputs.to(self.llm_model.device)

            with torch.no_grad():
                outputs = self.llm_model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=strategy['temp'],
                    do_sample=True,
                    top_p=strategy['top_p'],
                    top_k=strategy['top_k'],
                    pad_token_id=self.llm_tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens, not the echoed prompt.
            response = self.llm_tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
            return extract_answer(response)
        except Exception as e:
            print(f"Strategy {strategy['name']} error: {e}")
            return None

    def solve_ensemble(self, problem: str) -> Optional[int]:
        """Run the first `num_strategies` strategies and majority-vote their answers.

        Strategy i uses prompt template i (cycled). Strategies that produce
        no answer are ignored. Returns None when none produced an answer.
        """
        answers = []

        for i, strategy in enumerate(self.strategies[:self.num_strategies]):
            answer = self.solve_with_strategy(problem, strategy, i)
            if answer is not None:
                answers.append(answer)

        if not answers:
            return None

        # Plain majority vote; Counter.most_common breaks ties in favour of
        # the answer that was seen first.
        vote_counts = Counter(answers)
        best_answer, count = vote_counts.most_common(1)[0]

        if len(answers) > 1:
            print(f"Strategies: {answers} → {best_answer} ({count}/{len(answers)} votes)")

        return best_answer

    def solve_problem(self, problem_id: str, problem_text: str) -> int:
        """Answer one problem: symbolic pass, then LLM ensemble, then 0.

        `problem_id` is accepted for the caller's convenience but unused.
        The result is always forced into [0, 99999].
        """
        # Try symbolic first
        answer = self.solve_symbolic(problem_text)

        # Fall back to ensemble
        if answer is None:
            answer = self.solve_ensemble(problem_text)

        # Default to 0
        if answer is None:
            answer = 0

        if not validate_answer(answer):
            answer = abs(answer) % 100000

        return answer

# Shared solver instance; the LLM is only loaded when a strategy first needs it.
solver = MultiStrategyLLMSolver(
    num_strategies=5,
)

In [None]:
def predict(id_: pl.Series, problem: pl.Series) -> pl.DataFrame:
    """Inference-server callback: answer the problem in this call.

    Only the first element of `id_` and `problem` is read (presumably the
    gateway sends one problem per call — confirm). On any exception the
    error is printed and a row with answer 0 is returned instead.
    """
    try:
        qid = id_.item(0)
        prediction = solver.solve_problem(qid, problem.item(0))
        return pl.DataFrame({"id": [qid], "answer": [prediction]})
    except Exception as err:
        print(f"Prediction error: {err}")
        return pl.DataFrame({"id": [id_.item(0)], "answer": [0]})

# Server through which Kaggle's evaluation gateway calls predict().
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

# Set by Kaggle when the notebook is re-run against the hidden test set.
IS_COMPETITION_RERUN = bool(os.getenv('KAGGLE_IS_COMPETITION_RERUN'))

# Test on reference.csv (real IMO problems) for evaluation.
# Skipped during the competition rerun: each reference problem costs several
# full LLM generations, all of which would run before the server starts serving.
if not IS_COMPETITION_RERUN:
    print("=" * 60)
    print("TESTING ON REFERENCE.CSV (IMO-LEVEL PROBLEMS)")
    print("=" * 60)

    reference_path = '/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv'
    if not os.path.exists(reference_path):
        reference_path = '../data/reference.csv'

    if os.path.exists(reference_path):
        try:
            ref_df = pd.read_csv(reference_path)
            correct = 0
            total = len(ref_df)

            for idx, row in ref_df.iterrows():
                problem_id = row['id']
                problem_text = row['problem']
                print(f"\n[{idx+1}/{total}] Problem {problem_id}")

                try:
                    expected_answer = int(row['answer'])
                    print(f"Expected: {expected_answer}")
                    predicted_answer = solver.solve_problem(problem_id, problem_text)
                except Exception as e:
                    # One bad row (unparseable answer, solver crash) must not
                    # discard the rows already scored; count it as incorrect.
                    print(f"✗ ERROR: {e}")
                    continue

                print(f"Predicted: {predicted_answer}")
                if predicted_answer == expected_answer:
                    print("✓ CORRECT")
                    correct += 1
                else:
                    print("✗ INCORRECT")

            accuracy = (correct / total * 100) if total > 0 else 0
            print("\n" + "=" * 60)
            print(f"REFERENCE.CSV SCORE: {correct}/{total} ({accuracy:.1f}%)")
            print("=" * 60 + "\n")
        except Exception as e:
            print(f"Error testing reference.csv: {e}")
    else:
        print("reference.csv not found, skipping evaluation\n")

# Run inference server or local gateway
if IS_COMPETITION_RERUN:
    inference_server.serve()
else:
    print("Running local gateway on test.csv...")
    inference_server.run_local_gateway(
        ('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',)
    )