In [1]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import os
import sys
import re
import time
import polars as pl
import pandas as pd
import torch
import math
import numpy as np
import sympy
import itertools
import collections
import fractions
import gc
import signal
from collections import Counter
from io import StringIO
import kaggle_evaluation.aimo_3_inference_server
from transformers import AutoModelForCausalLM, AutoTokenizer

# ==========================================
# CONFIGURATION
# ==========================================
MODEL_PATH = "/kaggle/input/math_numina_7b_tir/pytorch/default/1/NuminaMath-7B-TIR"

class ProductionAIMO3Solver:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.is_loaded = False
        self.problem_cache = {}

    def load(self):
        if self.is_loaded: return
        
        print(f"‚è≥ Loading NuminaMath-7B-TIR...")
        
        if not os.path.exists(MODEL_PATH):
            print(f"‚ùå Model not found at {MODEL_PATH}")
            return

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
            
            max_mem = {0: "11GiB", 1: "11GiB", "cpu": "30GiB"}
            self.model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH,
                torch_dtype=torch.float16,
                device_map="auto",
                max_memory=max_mem,
                local_files_only=True,
                low_cpu_mem_usage=True
            )
            print("‚úÖ Model Loaded (Float16 Split)!")
            self.is_loaded = True
            
        except Exception as e:
            print(f"‚ùå Load Failed: {e}")
            return

    def run_python_code_with_timeout(self, code, timeout=10):
        """REAL Python execution with TIMEOUT"""
        sys.set_int_max_str_digits(0)
        local_scope = {
            "math": math, "np": np, "sympy": sympy, "itertools": itertools,
            "collections": collections, "fractions": fractions, "print": print,
            "sys": sys, "Counter": Counter
        }
        
        old_stdout = sys.stdout
        redirected_output = sys.stdout = StringIO()
        
        def timeout_handler(signum, frame):
            raise TimeoutError("Code execution timed out")
        
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(timeout)
        
        try:
            code = re.sub(r"```python|```", "", code)
            exec(code, local_scope)
            signal.alarm(0)
            sys.stdout = old_stdout
            return redirected_output.getvalue().strip(), "SUCCESS"
        except TimeoutError:
            signal.alarm(0)
            sys.stdout = old_stdout
            return "Execution timed out", "TIMEOUT"
        except Exception as e:
            signal.alarm(0)
            sys.stdout = old_stdout
            return f"Runtime Error: {e}", "ERROR"

    def extract_answer(self, text):
        try:
            numbers = re.findall(r'\b\d{1,5}\b', text)
            if numbers:
                answer = int(numbers[-1])
                if 0 <= answer <= 99999:
                    return answer
        except:
            pass
        return None

    def solve(self, problem_text, max_retries=1):
        if not self.is_loaded: 
            self.load()
        if not self.is_loaded: 
            return 0
        
        problem_hash = hash(problem_text)
        if problem_hash in self.problem_cache:
            return self.problem_cache[problem_hash]
        
        print(f"üß† Solving: {problem_text[:50]}...")
        start_time = time.time()
        
        prompt = f"""Solve this problem with concise Python code. Print only the final integer answer.

Problem: {problem_text}

Python code:
```python
"""
        
        for attempt in range(max_retries + 1):
            gc.collect()
            torch.cuda.empty_cache()
            
            try:
                inputs = self.tokenizer(
                    prompt, 
                    return_tensors="pt", 
                    max_length=1024,
                    truncation=True,
                    padding=True
                ).to(self.model.device)
                
                with torch.no_grad():
                    generated_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=512,
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id,
                        repetition_penalty=1.1
                    )
                
                response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                response = response[len(prompt):] if response.startswith(prompt) else response
                
                code_match = re.search(r"```python(.*?)```", response, re.DOTALL)
                if code_match:
                    code = code_match.group(1).strip()
                    if code:
                        output, status = self.run_python_code_with_timeout(code, timeout=15)
                        
                        if status == "SUCCESS":
                            answer = self.extract_answer(output)
                            if answer is not None:
                                elapsed = time.time() - start_time
                                print(f"    ‚úÖ Code success: {answer} ({elapsed:.1f}s)")
                                self.problem_cache[problem_hash] = answer
                                return answer
                        else:
                            print(f"    ‚ùå {status}: {output}")
                
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print("    ‚ö†Ô∏è OOM - clearing cache")
                    torch.cuda.empty_cache()
                    continue
                else:
                    print(f"    ‚ùå Runtime error: {e}")
            except Exception as e:
                print(f"    ‚ùå Generation error: {e}")
                continue
        
        elapsed = time.time() - start_time
        print(f"‚ùå Failed after {elapsed:.1f}s")
        return 0

# ==========================================
# ACTUAL REFERENCE.CSV TESTING
# ==========================================
solver = ProductionAIMO3Solver()

def test_reference_problems():
    """ACTUALLY test on reference.csv problems"""
    try:
        # Load the reference problems
        reference_df = pd.read_csv('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv')
        print(f"üìä TESTING ON {len(reference_df)} REFERENCE PROBLEMS")
        print("=" * 80)
        
        correct = 0
        total = len(reference_df)
        total_time = 0
        
        for i in range(min(3, total)):  # Test first 3 to save time
            row = reference_df.iloc[i]
            problem_text = row['problem']
            true_answer = row['answer']
            
            print(f"\nüî¢ REFERENCE PROBLEM {i+1}/{min(3, total)}")
            print(f"   TRUE ANSWER: {true_answer}")
            print(f"   PROBLEM: {problem_text[:80]}...")
            
            start_time = time.time()
            predicted = solver.solve(problem_text)
            elapsed = time.time() - start_time
            total_time += elapsed
            
            status = "‚úÖ CORRECT" if predicted == true_answer else "‚ùå WRONG"
            print(f"   {status} | PREDICTED: {predicted} | TIME: {elapsed:.1f}s")
            
            if predicted == true_answer:
                correct += 1
            
            torch.cuda.empty_cache()
            time.sleep(1)
        
        accuracy = correct / min(3, total) * 100
        avg_time = total_time / min(3, total)
        
        print(f"\nüéØ REFERENCE SET PERFORMANCE: {correct}/{min(3, total)} correct ({accuracy:.1f}%)")
        print(f"‚è±Ô∏è  Average time: {avg_time:.1f}s per problem")
        
        if accuracy >= 50:
            print("üéâ READY FOR COMPETITION SUBMISSION!")
        else:
            print("‚ö†Ô∏è  Needs improvement before submission")
            
        return accuracy
        
    except Exception as e:
        print(f"‚ùå Error loading reference problems: {e}")
        return 0

def predict(id_: pl.Series, problem: pl.Series) -> pl.DataFrame:
    id_val = id_.item(0)
    problem_text = problem.item(0)
    try:
        answer = solver.solve(problem_text)
    except:
        answer = 0
    return pl.DataFrame({'id': [id_val], 'answer': [answer]})

inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    print("üöÄ Starting Production Server...")
    inference_server.serve()
else:
    print("üî¨ Local Test Mode...")
    # FINALLY TEST ON ACTUAL REFERENCE PROBLEMS
    accuracy = test_reference_problems()

üî¨ Local Test Mode...
üìä TESTING ON 10 REFERENCE PROBLEMS

üî¢ REFERENCE PROBLEM 1/3
   TRUE ANSWER: 336
   PROBLEM: Let $ABC$ be an acute-angled triangle with integer side lengths and $AB<AC$. Poi...
‚è≥ Loading NuminaMath-7B-TIR...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

‚úÖ Model Loaded (Float16 Split)!
üß† Solving: Let $ABC$ be an acute-angled triangle with integer...
‚ùå Failed after 63.0s
   ‚ùå WRONG | PREDICTED: 0 | TIME: 107.2s

üî¢ REFERENCE PROBLEM 2/3
   TRUE ANSWER: 32951
   PROBLEM: Define a function $f \colon \mathbb{Z}_{\geq 1} \to \mathbb{Z}_{\geq 1}$ by
\beg...
üß† Solving: Define a function $f \colon \mathbb{Z}_{\geq 1} \t...
‚ùå Failed after 62.5s
   ‚ùå WRONG | PREDICTED: 0 | TIME: 62.5s

üî¢ REFERENCE PROBLEM 3/3
   TRUE ANSWER: 21818
   PROBLEM: A tournament is held with $2^{20}$ runners each of which has a different running...
üß† Solving: A tournament is held with $2^{20}$ runners each of...
‚ùå Failed after 62.5s
   ‚ùå WRONG | PREDICTED: 0 | TIME: 62.5s

üéØ REFERENCE SET PERFORMANCE: 0/3 correct (0.0%)
‚è±Ô∏è  Average time: 77.4s per problem
‚ö†Ô∏è  Needs improvement before submission
