In [None]:
# TIR-Enabled Ensemble with Self-Consistency
# Based on AIMO Progress Prize 1 & 2 winning approaches

import os
import sys

# Suppress common warnings and CUDA messages
# Quiet noisy third-party start-up output. These variables are set before any
# TensorFlow / tokenizers / torch import below so the libraries can see them.
os.environ.update({
    'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION': 'python',
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'CUDA_LAUNCH_BLOCKING': '0',
    'TOKENIZERS_PARALLELISM': 'false',  # Avoid tokenizer warnings in parallel
})

import logging
import warnings

# Blanket-silence Python warnings and the TensorFlow/absl loggers.
# NOTE(review): 'ignore' also hides genuinely useful warnings (e.g. deprecations).
warnings.filterwarnings('ignore')
logging.getLogger('tensorflow').setLevel(logging.ERROR)
logging.getLogger('absl').setLevel(logging.ERROR)

# Fix for protobuf MessageFactory.GetPrototype compatibility issue
import google.protobuf.message_factory as message_factory
from google.protobuf import descriptor_pool

def GetPrototype(self, descriptor):
    """Compatibility shim standing in for ``MessageFactory.GetPrototype``.

    Installed below only when the installed protobuf's ``MessageFactory`` has no
    ``GetPrototype`` attribute. Delegates to an instance ``GetMessageClass``
    method if the factory has one; otherwise falls back to the module-level
    ``message_factory.GetMessageClass`` function.
    NOTE(review): newer protobuf exposes ``GetMessageClass`` as a module function,
    not a ``MessageFactory`` method -- hence the fallback; confirm against the
    protobuf version pinned in the Kaggle image.

    Args:
        self: The ``MessageFactory`` instance the method is bound to.
        descriptor: Message descriptor whose generated class is requested.

    Returns:
        The generated message class for ``descriptor``.
    """
    get_message_class = getattr(self, 'GetMessageClass', None)
    if get_message_class is None:
        get_message_class = message_factory.GetMessageClass
    return get_message_class(descriptor)


# Only patch when the attribute is missing, so older protobuf keeps its own method.
if not hasattr(message_factory.MessageFactory, 'GetPrototype'):
    message_factory.MessageFactory.GetPrototype = GetPrototype

import pandas as pd
import polars as pl
import re
import subprocess
import tempfile
from typing import Optional, List, Tuple
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
import kaggle_evaluation.aimo_3_inference_server

In [None]:
import subprocess
import tempfile
import os

# Default source prepended to every executed snippet so model code can use the
# usual math libraries without importing them itself.
_DEFAULT_CODE_PREAMBLE = """import math
import numpy as np
try:
    import sympy as sp
    from sympy import *
except ImportError:
    pass

"""


def _print_trailing_expression(code: str) -> str:
    """Wrap the final top-level expression statement of ``code`` in ``print(...)``.

    This mimics a notebook cell, where a bare trailing expression is displayed.
    Only a genuine top-level expression statement is wrapped. Assignments, loops,
    imports, indented lines and existing ``print(...)`` calls are left alone, so
    the rewrite cannot turn valid code into a syntax/runtime error. Source that
    does not parse is returned unchanged (apart from newline normalization).
    """
    import ast

    normalized = code.replace('\r\n', '\n').replace('\r', '\n')
    try:
        tree = ast.parse(normalized)
    except Exception:  # SyntaxError, ValueError (NUL bytes), RecursionError, ...
        return normalized
    if not tree.body or not isinstance(tree.body[-1], ast.Expr):
        return normalized

    last = tree.body[-1]
    value = last.value
    if isinstance(value, ast.Call) and isinstance(value.func, ast.Name) and value.func.id == 'print':
        return normalized

    lines = normalized.split('\n')
    first, final = last.lineno - 1, last.end_lineno - 1
    # ast column offsets are UTF-8 byte offsets; convert them to str indices.
    start = len(lines[first].encode('utf-8')[:last.col_offset].decode('utf-8', 'ignore'))
    stop = len(lines[final].encode('utf-8')[:last.end_col_offset].decode('utf-8', 'ignore'))

    if first == final:
        expr_src = lines[first][start:stop]
    else:
        expr_src = '\n'.join([lines[first][start:], *lines[first + 1:final], lines[final][:stop]])
    # Double parentheses keep tuple expressions (``a, b``) printing as one value.
    wrapped = lines[first][:start] + 'print((' + expr_src + '))' + lines[final][stop:]
    return '\n'.join(lines[:first] + [wrapped] + lines[final + 1:])


def _last_answer_in_output(output: str) -> Optional[int]:
    """Return the last integral value in [0, 99999] printed in ``output``.

    Lines are scanned bottom-up and numeric tokens right-to-left. ``42.0`` counts
    as 42 (numpy/sympy often print floats). Fractional (``3.14``), negative or
    out-of-range tokens are skipped instead of being split into digit fragments.
    """
    number = re.compile(r'(?<![\w.])-?\d+(?:\.\d+)?(?!\w)')
    for line in reversed(output.strip().splitlines()):
        for token in reversed(number.findall(line)):
            whole, _, fraction = token.partition('.')
            if fraction.strip('0'):
                continue  # genuinely fractional value
            value = int(whole)
            if 0 <= value <= 99999:
                return value
    return None


def execute_python_code(code: str, timeout: int = 10,
                        preamble: Optional[str] = None) -> tuple[Optional[int], str]:
    """Run model-generated Python in a subprocess and extract an integer answer.

    The snippet is prefixed with ``preamble`` (by default: ``math``, ``numpy``
    and, if installed, ``sympy`` imports). Notebook-style, its final bare
    expression is printed.

    SECURITY: the code is NOT sandboxed. It runs with this notebook's
    interpreter, filesystem and network access; only ``timeout`` bounds it.

    Args:
        code: Python source extracted from a model response.
        timeout: Wall-clock limit in seconds for the subprocess.
        preamble: Source prepended to ``code``. ``None`` selects the default
            math-library imports; ``""`` runs ``code`` as-is.

    Returns:
        ``(answer, output)``. ``output`` is the combined stdout+stderr, or the
        error message if the run could not be started or timed out. ``answer``
        is the last integral value in [0, 99999] printed to stdout. It is
        ``None`` when nothing qualifies or the script exited with an error.
    """
    import sys

    temp_file = None
    try:
        prefix = _DEFAULT_CODE_PREAMBLE if preamble is None else preamble
        script = prefix + _print_trailing_expression(code)
        # Python source is UTF-8 by default; don't depend on the locale encoding.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False,
                                         encoding='utf-8') as handle:
            temp_file = handle.name
            handle.write(script)
        result = subprocess.run(
            [sys.executable or 'python3', temp_file],  # same interpreter/env as the notebook
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except Exception as exc:  # timeout, bad input, unwritable temp dir, ...
        return None, str(exc)
    finally:
        if temp_file is not None:
            try:
                os.unlink(temp_file)
            except OSError:
                pass

    output = result.stdout + result.stderr
    # A failing script's stderr is a traceback full of line numbers: never mine
    # it for an answer, and don't trust partial stdout from a run that crashed.
    if result.returncode != 0:
        return None, output
    return _last_answer_in_output(result.stdout), output

def extract_code_from_response(text: str) -> Optional[str]:
    """Return the body of the last fenced code block in a model response.

    Blocks tagged ``python``, ``python3`` or ``py`` (any case, optional spaces,
    LF or CRLF line endings) are preferred. If there are none, the last untagged
    fenced block is used.

    Args:
        text: Raw model response.

    Returns:
        The stripped code of the chosen block, or ``None`` if ``text`` is empty
        or contains no fenced block.
    """
    if not text:
        return None

    code_patterns = [
        r'```[ \t]*(?:python3?|py)[ \t]*\r?\n(.*?)```',  # language-tagged fences
        r'```[ \t]*\r?\n(.*?)```',                       # untagged fences
    ]
    for pattern in code_patterns:
        matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
        if matches:
            return matches[-1].strip()
    return None

def extract_answer(text: str) -> Optional[int]:
    """Extract a numerical answer in [0, 99999] from model output text.

    The patterns below are tried in priority order (``\\boxed{}`` first). Within
    a pattern, the LAST match wins. Numbers must stand alone: a 6+ digit value
    is never truncated to five digits. Comma-grouped values such as ``12,345``
    are accepted. If no pattern yields a valid answer, the last standalone 1-5
    digit number in the final 500 characters is used.

    Args:
        text: Raw model response.

    Returns:
        The extracted integer, or ``None`` if nothing valid is found.
    """
    if not text:
        return None

    # A standalone number: not preceded by a digit/comma, not followed by a digit.
    num = r'(?<![\d,])(\d{1,3}(?:,\d{3})+|\d{1,5})(?!\d)'
    patterns = [
        r'\\boxed\{\s*' + num + r'\s*\}',
        r'oxed\{\s*' + num + r'\s*\}',  # Sometimes the leading backslash is dropped
        r'####\s*' + num,
        r'final\s+answer\s*(?:is|:)?\s*' + num,
        r'(?:answer|result|solution)(?:\s+is)?:?\s*' + num,
        r'=\s*' + num + r'(?:\s|$|\.|,)',
        num + r'(?:\s+(?:is the|as the) answer)',
        r'therefore.*?' + num,
        r'answer\s*=\s*' + num,  # Variable assignment
    ]

    for pattern in patterns:
        matches = re.findall(pattern, text, re.IGNORECASE | re.DOTALL)
        if matches:
            answer = int(matches[-1].replace(',', ''))
            if 0 <= answer <= 99999:
                return answer

    # Fallback: last standalone 1-5 digit number near the end of the text.
    numbers = re.findall(r'\b(\d{1,5})\b', text[-500:])
    if numbers:
        answer = int(numbers[-1])
        if 0 <= answer <= 99999:
            return answer

    return None

def validate_answer(answer: int) -> bool:
    """Return True iff ``answer`` is a plain integer in the AIMO range [0, 99999].

    ``bool`` is rejected explicitly. It subclasses ``int`` in Python, so
    ``True`` would otherwise be accepted as the answer 1.
    """
    return (
        isinstance(answer, int)
        and not isinstance(answer, bool)
        and 0 <= answer <= 99999
    )

In [None]:
# Removed MathTools and symbolic solver to comply with competition rules
# Competition requires model-based solutions only

In [None]:
import os
import glob
import threading

def find_model_path(base_path):
    """Auto-detect the actual model directory within a dataset.

    Walks ``base_path`` top-down and returns the first directory containing a
    ``config.json`` file. ``base_path`` itself is checked first. Sub-directories
    are visited in sorted order, so the result is deterministic when a dataset
    holds several model folders (``os.walk`` order otherwise depends on the
    filesystem).

    Args:
        base_path: Dataset root to search. May be empty, ``None`` or missing.

    Returns:
        The matching directory path, or ``None`` if nothing is found.
    """
    if not base_path or not os.path.isdir(base_path):
        return None

    for root, dirs, files in os.walk(base_path):
        if 'config.json' in files:
            return root
        dirs.sort()  # in-place: controls the order os.walk descends

    return None


class ModelConfig:
    """One sampling strategy: which checkpoint to load and how to sample from it.

    ``kaggle_path`` is a local dataset directory that is searched for a model
    folder. ``hf_path`` (e.g. a Hugging Face Hub id) is the fallback when the
    local path is empty or yields nothing.
    """

    def __init__(self, name: str, kaggle_path: str, hf_path: str, temp: float = 0.7,
                 top_p: float = 0.9, top_k: int = 50, needs_trust: bool = False,
                 num_samples: int = 1):
        # Identity and checkpoint locations.
        self.name, self.kaggle_path, self.hf_path = name, kaggle_path, hf_path
        # Sampling hyper-parameters.
        self.temp, self.top_p, self.top_k = temp, top_p, top_k
        # Forwarded as trust_remote_code when loading.
        self.needs_trust = needs_trust
        # Number of independent samples drawn with this strategy.
        self.num_samples = num_samples

    def get_path(self):
        """Return the detected local checkpoint directory, else ``hf_path``."""
        local_dir = find_model_path(self.kaggle_path) if self.kaggle_path else None
        return local_dir or self.hf_path


class EnsembleSolver:
    """TIR-enabled solver with parallel sampling and early stopping.

    For each problem, the strategies in ``model_configs`` run in order. Each
    strategy works as follows:

    * its model is loaded;
    * ``num_samples`` generations run concurrently in threads;
    * each response is reduced to an integer, by executing its last Python code
      block and otherwise by regex extraction from the text;
    * the model is unloaded.

    Once a single answer has been seen ``early_stop_threshold`` times, the
    remaining strategies are skipped. All collected answers are combined with
    :meth:`smart_voting`.
    """

    def __init__(self, early_stop_threshold: int = 10):
        # Votes a single answer needs before the remaining strategies are skipped.
        self.early_stop_threshold = early_stop_threshold

        # Strategies, run in this order. Configs sharing a checkpoint differ only
        # in sampling settings, to diversify the samples.
        self.model_configs = [
            # Qwen2.5-Math-1.5B: small and fast, so more samples.
            ModelConfig(
                "Qwen-1.5B",
                '/kaggle/input/qwen2-5-math-1-5b-instruct',
                'Qwen/Qwen2.5-Math-1.5B-Instruct',
                temp=0.7, top_p=0.9, top_k=50, num_samples=5
            ),

            # Qwen2.5-Math-7B at three temperatures.
            ModelConfig(
                "Qwen-7B-Low",
                '/kaggle/input/qwen2-5-math-7b-instruct',
                'Qwen/Qwen2.5-Math-7B-Instruct',
                temp=0.6, top_p=0.9, top_k=40, num_samples=3
            ),
            ModelConfig(
                "Qwen-7B-Mid",
                '/kaggle/input/qwen2-5-math-7b-instruct',
                'Qwen/Qwen2.5-Math-7B-Instruct',
                temp=0.8, top_p=0.92, top_k=50, num_samples=3
            ),
            ModelConfig(
                "Qwen-7B-High",
                '/kaggle/input/qwen2-5-math-7b-instruct',
                'Qwen/Qwen2.5-Math-7B-Instruct',
                temp=1.0, top_p=0.95, top_k=60, num_samples=4
            ),

            # DeepSeek-Math-7B-RL at three temperatures.
            ModelConfig(
                "DeepSeek-Low",
                '/kaggle/input/deepseek-math-7b-rl',
                'deepseek-ai/deepseek-math-7b-rl',
                temp=0.5, top_p=0.88, top_k=35, num_samples=3,
                needs_trust=True
            ),
            ModelConfig(
                "DeepSeek-Mid",
                '/kaggle/input/deepseek-math-7b-rl',
                'deepseek-ai/deepseek-math-7b-rl',
                temp=0.7, top_p=0.92, top_k=50, num_samples=3,
                needs_trust=True
            ),
            ModelConfig(
                "DeepSeek-High",
                '/kaggle/input/deepseek-math-7b-rl',
                'deepseek-ai/deepseek-math-7b-rl',
                temp=0.9, top_p=0.95, top_k=60, num_samples=4,
                needs_trust=True
            ),

            # MAmmoTH-7B (Mistral base).
            ModelConfig(
                "MAmmoTH-7B",
                '/kaggle/input/mammoth-7b-mistral',
                'TIGER-Lab/MAmmoTH-7B-Mistral',
                temp=0.8, top_p=0.92, top_k=50, num_samples=5,
                needs_trust=True
            ),
        ]

    def solve_with_model_single(self, problem: str, model, tokenizer, model_config: ModelConfig,
                                sample_idx: int, stop_event: threading.Event) -> Optional[int]:
        """Generate one sampled solution and reduce it to an integer answer.

        Args:
            problem: Problem statement inserted into the prompt.
            model: Loaded causal LM (``generate`` is called on it).
            tokenizer: Matching tokenizer.
            model_config: Supplies the sampling parameters (temp/top_p/top_k).
            sample_idx: Index of this sample (used only in log messages).
            stop_event: If already set, no generation is attempted.

        Returns:
            An answer in [0, 99999], or ``None`` when generation fails or no valid
            answer can be extracted. Never raises; failures are logged.
        """
        try:
            import torch

            if stop_event.is_set():
                return None

            prompt = f"""You are solving a challenging International Mathematical Olympiad problem. 

Problem:
{problem}

Approach:
1. Read the problem carefully and identify what's being asked
2. Break down the problem into smaller steps
3. Use Python code to compute intermediate results - wrap code in ```python blocks
4. For combinatorics/number theory, write explicit Python calculations
5. Verify your answer makes sense given the constraints
6. The answer MUST be an integer between 0 and 99999

Important:
- Write Python code to calculate the answer when possible
- Use sympy for symbolic math, numpy for numerical computations
- Print intermediate results to verify your work
- Put your final answer in \\boxed{{answer}} format

Solution:"""

            if getattr(tokenizer, 'chat_template', None) is not None:
                messages = [{"role": "user", "content": prompt}]
                text = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                # The rendered template already contains any BOS/special tokens;
                # letting the tokenizer add them again would duplicate them.
                inputs = tokenizer([text], return_tensors="pt", add_special_tokens=False)
            else:
                inputs = tokenizer([prompt], return_tensors="pt")
            if hasattr(model, 'device'):
                inputs = inputs.to(model.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=3072,
                    temperature=model_config.temp,
                    do_sample=True,
                    top_p=model_config.top_p,
                    top_k=model_config.top_k,
                    pad_token_id=tokenizer.eos_token_id,
                )

            # Decode only the newly generated tokens (drop the echoed prompt).
            prompt_len = inputs['input_ids'].shape[1]
            response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

            # Tool-Integrated Reasoning first: trust executed code over free text.
            answer = None
            code = extract_code_from_response(response)
            if code:
                answer, _ = execute_python_code(code, timeout=20)
            if answer is None:
                answer = extract_answer(response)

            if answer is not None and validate_answer(answer):
                return answer
            return None

        except Exception as exc:
            # Best effort: one failed sample must not sink the ensemble, but make
            # the failure visible (e.g. CUDA OOM) instead of swallowing it.
            print(f"    [{model_config.name} #{sample_idx}] Sample failed: {str(exc)[:80]}")
            return None

    def solve_with_model(self, problem: str, model_config: ModelConfig,
                         stop_event: threading.Event) -> List[int]:
        """Load one strategy's model, draw its samples concurrently, then unload it.

        Args:
            problem: Problem statement.
            model_config: Strategy to run (checkpoint + sampling parameters).
            stop_event: If set before loading, the strategy is skipped.

        Returns:
            The valid integer answers produced by the samples (possibly empty).
            Never raises; load/sampling errors are logged and yield ``[]``.
        """
        try:
            import gc

            import torch
            from transformers import AutoModelForCausalLM, AutoTokenizer

            if stop_event.is_set():
                print(f"    [{model_config.name}] Skipped (early stop)")
                return []

            model_path = model_config.get_path()
            if not model_path:
                print(f"    [{model_config.name}] Not found")
                return []

            print(f"  [{model_config.name}] Loading ({model_config.num_samples} samples)...")

            tokenizer = model = None
            answers: List[int] = []
            try:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(
                        model_path,
                        use_fast=True,
                        trust_remote_code=model_config.needs_trust
                    )
                    model = AutoModelForCausalLM.from_pretrained(
                        model_path,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto" if torch.cuda.is_available() else None,
                        trust_remote_code=model_config.needs_trust,
                        low_cpu_mem_usage=True
                    )
                except Exception as e:
                    print(f"    Load failed: {str(e)[:80]}")
                    return []

                model.eval()  # inference only: make sure dropout etc. is disabled

                # max_workers must be >= 1 even if a config asks for 0 samples.
                with ThreadPoolExecutor(max_workers=max(1, model_config.num_samples)) as executor:
                    futures = [
                        executor.submit(self.solve_with_model_single, problem, model,
                                        tokenizer, model_config, i, stop_event)
                        for i in range(model_config.num_samples)
                    ]
                    for future in as_completed(futures):
                        # NOTE: breaking does not cancel running generations; the
                        # executor still waits for them when the `with` block exits.
                        if stop_event.is_set():
                            break
                        try:
                            answer = future.result()
                        except Exception as exc:
                            print(f"    [{model_config.name}] Sample error: {str(exc)[:80]}")
                            continue
                        if answer is not None:
                            answers.append(answer)
            finally:
                # Drop the weights and return cached CUDA blocks even on failure,
                # so the next strategy starts with free GPU memory.
                del model, tokenizer
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            if answers:
                print(f"    → Got {len(answers)}/{model_config.num_samples} answers: {answers}")
            else:
                print("    → No valid answers")

            return answers

        except Exception as e:
            print(f"  Error: {str(e)[:80]}")
            return []

    def smart_voting(self, all_answers: List[int]) -> Optional[int]:
        """Pick a final answer from all sampled answers.

        Rules:

        * Accept the plurality answer when it has at least 3 votes and at least
          30% of all votes.
        * Even then, a 0/1 winner without a majority, or a 99999 winner below
          40%, is replaced by a runner-up with >= 20% support (minimum 2 votes)
          if one exists.
        * Without consensus, return the most common answer that is not 0/1.

        Returns ``None`` only for an empty input.
        """
        if not all_answers:
            return None

        vote_counts = Counter(all_answers)
        total_votes = len(all_answers)
        best_answer, best_count = vote_counts.most_common(1)[0]

        if best_count >= max(3, total_votes * 0.3):
            # Trivial answers (0/1) without a majority are suspicious.
            if best_answer in (0, 1) and best_count < total_votes * 0.5:
                for ans, count in vote_counts.most_common(5):
                    if ans not in (0, 1) and count >= max(2, total_votes * 0.2):
                        return ans

            # The range ceiling 99999 needs strong support too.
            if best_answer == 99999 and best_count < total_votes * 0.4:
                for ans, count in vote_counts.most_common(5):
                    if ans != 99999 and count >= max(2, total_votes * 0.2):
                        return ans

            return best_answer

        # No consensus: most common answer that is not 0/1.
        for ans, _count in vote_counts.most_common(10):
            if ans not in (0, 1):
                return ans

        # Last resort: plain plurality.
        return best_answer

    def solve_with_ensemble(self, problem: str) -> Optional[int]:
        """Run the strategies in order with early stopping, then vote.

        Returns the :meth:`smart_voting` result over all collected answers, or
        ``None`` if no strategy produced a valid answer.
        """
        all_answers: List[int] = []
        stop_event = threading.Event()

        print(f"  Running TIR ensemble with early stopping (threshold={self.early_stop_threshold})...")

        for model_config in self.model_configs:
            if stop_event.is_set():
                print("  ⏹️  Early stop triggered, skipping remaining models")
                break

            all_answers.extend(self.solve_with_model(problem, model_config, stop_event))

            # Stop as soon as any single answer has enough votes.
            if len(all_answers) >= self.early_stop_threshold:
                most_common_ans, count = Counter(all_answers).most_common(1)[0]
                if count >= self.early_stop_threshold:
                    print(f"  🎯 Early stop! Answer {most_common_ans} appeared {count} times")
                    stop_event.set()
                    break

        if not all_answers:
            return None

        best_answer = self.smart_voting(all_answers)
        vote_counts = Counter(all_answers)

        print(f"  Total samples: {len(all_answers)}")
        print(f"  Vote distribution: {dict(vote_counts.most_common(5))}")
        print(f"  Smart voting result: {best_answer}")

        return best_answer

    def solve_problem(self, problem_id: str, problem_text: str) -> int:
        """Solve one problem, always returning an int in [0, 99999] (0 if unsolved)."""
        answer = self.solve_with_ensemble(problem_text)

        if answer is None:
            answer = 0

        if not validate_answer(answer):
            answer = abs(answer) % 100000

        return answer

solver = EnsembleSolver(early_stop_threshold=10)

In [None]:
# Describe the configured ensemble. The figures are derived from `solver` so the
# banner cannot drift out of sync with the actual configuration.
_configs = solver.model_configs
_total_samples = sum(cfg.num_samples for cfg in _configs)
_n_checkpoints = len({cfg.hf_path for cfg in _configs})

print("=" * 70)
print("TIR-ENABLED ENSEMBLE with Smart Voting & Early Stopping")
print("=" * 70)
print("Based on AIMO Prize 1 & 2 winning solutions + GPT-OSS-120B notebook:")
print("  ✓ Tool-Integrated Reasoning (Python code execution)")
print("  ✓ Enhanced code execution (auto math library imports)")
print("  ✓ Parallel sampling per model (ThreadPoolExecutor)")
print(f"  ✓ Early stopping when {solver.early_stop_threshold} samples agree (saves time!)")
print("  ✓ Smart voting (filters outliers, requires consensus)")
print("  ✓ Improved answer extraction (multiple regex patterns + fallback)")
print(f"  ✓ ~{_total_samples} total samples per problem (or less with early stop)")
print("")
print(f"Models ({_n_checkpoints} different models, {len(_configs)} configs):")
for _cfg in _configs:
    print(f"  • {_cfg.name}: {_cfg.num_samples} parallel samples (temp={_cfg.temp})")
print("=" * 70)

# Global tracking for accuracy calculation (local mode only)
predictions = {}
correct_count = 0
total_count = 0
ground_truth = {}

# Load reference data for local testing
reference_path = '/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv'
if os.path.exists(reference_path):
    df = pd.read_csv(reference_path)
    # Store ground truth if available
    if "answer" in df.columns:
        ground_truth = dict(zip(df["id"], df["answer"]))
        print(f"\n📋 Loaded {len(ground_truth)} reference answers for accuracy tracking")

def predict(id_series: pl.Series, problem_series: pl.Series, *extra) -> pl.DataFrame:
    """Prediction function for inference server.

    The gateway unpacks the row DataFrame by columns, so we receive Series objects.
    For reference.csv (3 cols): id, problem, answer
    For test.csv (2 cols): id, problem

    Side effects: records the answer in the global ``predictions`` and updates
    ``total_count`` / ``correct_count`` (the latter only for ids present in
    ``ground_truth``).

    Args:
        id_series: Series with the problem ID
        problem_series: Series with the problem text
        *extra: Any additional columns (like 'answer' in reference.csv); ignored.

    Returns:
        DataFrame with 'id' and 'answer' columns. On any error the answer is 0
        (and the id is "unknown" if it cannot be read).
    """
    global correct_count, total_count, predictions

    try:
        question_id = id_series[0]
        question_text = problem_series[0]

        print("=" * 60)
        print(f"ID: {question_id}")
        print(f"Question: {question_text[:200]}...")

        answer = solver.solve_problem(question_id, question_text)
        predictions[question_id] = answer

        total_count += 1
        if question_id in ground_truth:
            gt = ground_truth[question_id]
            is_correct = (answer == gt)
            if is_correct:
                correct_count += 1
            status = "✅" if is_correct else "❌"
            print(f"\nAnswer: {answer} | Ground Truth: {gt} | {status}")
            print(f"📊 Running Accuracy: {correct_count}/{total_count} ({100*correct_count/total_count:.1f}%)")
        else:
            print(f"\nAnswer: {answer}")

        print("=" * 60 + "\n")

        return pl.DataFrame({"id": [question_id], "answer": [answer]})
    except Exception as e:
        print(f"Prediction error: {e}")
        import traceback
        traceback.print_exc()

        # Never fail the gateway: return a default answer for this id.
        try:
            fallback_id = id_series[0]
        except Exception:
            fallback_id = "unknown"

        return pl.DataFrame({"id": [fallback_id], "answer": [0]})

inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

# Auto-detect: competition rerun vs local testing
if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    print("\n[COMPETITION MODE] Starting inference server...")
    inference_server.serve()
else:
    print("\n[LOCAL MODE] Running local gateway on reference.csv...")
    inference_server.run_local_gateway((reference_path,))

    # Final accuracy summary (only when reference answers were loaded).
    if ground_truth and total_count > 0:
        print("\n" + "=" * 60)
        print("📊 FINAL ACCURACY SUMMARY")
        print("=" * 60)
        print(f"Correct: {correct_count}/{total_count}")
        print(f"Accuracy: {100*correct_count/total_count:.1f}%")
        print("=" * 60)

        # Details of incorrect predictions.
        incorrect = [
            (qid, pred, ground_truth[qid])
            for qid, pred in predictions.items()
            if qid in ground_truth and pred != ground_truth[qid]
        ]

        if incorrect:
            print("\n❌ Incorrect Predictions:")
            for qid, pred, gt in incorrect:
                print(f"  {qid}: predicted={pred}, actual={gt}")