In [1]:
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation/runtime warnings from torch and
# transformers — presumably done to keep notebook output short; confirm.
warnings.filterwarnings("ignore")

In [2]:
import os
import sys
import re
import time
import polars as pl
import pandas as pd
import torch
import math
import numpy as np
import sympy
import itertools
import collections
import fractions
import gc
import signal
from collections import Counter
from io import StringIO
import kaggle_evaluation.aimo_3_inference_server
from transformers import AutoModelForCausalLM, AutoTokenizer

# ==========================================
# 0. MEMORY OPTIMIZATION
# ==========================================
# Allocator option passed to PyTorch through the environment.
# NOTE(review): torch is already imported above; presumably the CUDA allocator
# reads this variable lazily on first CUDA use — confirm it still takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# ==========================================
# 1. CONFIGURATION
# ==========================================
# Local directory holding the model weights and tokenizer; the loader checks
# that it exists and loads from it with local_files_only=True.
MODEL_PATH = "/kaggle/input/math_numina_7b_tir/pytorch/default/1/NuminaMath-7B-TIR"

class SpecialistAIMO3Solver:
    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.is_loaded = False
        self.problem_cache = {}

    def load(self):
        if self.is_loaded: return
        
        print(f"‚è≥ Loading NuminaMath-7B-TIR...")
        
        if not os.path.exists(MODEL_PATH):
            print(f"‚ùå Model not found at {MODEL_PATH}")
            return

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
            
            # --- DYNAMIC HARDWARE DETECTION ---
            # Check how many GPUs are actually available
            num_gpus = torch.cuda.device_count()
            print(f"üîç Detected {num_gpus} GPU(s)")
            
            if num_gpus == 0:
                print("‚ùå No GPU detected! This model cannot run on CPU within time limits.")
                return
                
            # Check VRAM of the first GPU
            vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            print(f"üîç Primary GPU VRAM: {vram_gb:.1f} GB")
            
            if vram_gb > 24:
                # Case A: H100 (80GB), A100 (40GB), L4 (24GB) -> Single Device Load
                print("üöÄ High-VRAM GPU detected! Loading full model on GPU 0...")
                self.model = AutoModelForCausalLM.from_pretrained(
                    MODEL_PATH,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    local_files_only=True,
                    low_cpu_mem_usage=True
                )
            elif num_gpus > 1:
                # Case B: 2x T4 (15GB each) -> Split across devices
                print("‚ö†Ô∏è Dual T4 GPUs detected. Splitting model...")
                max_mem = {0: "11GiB", 1: "11GiB", "cpu": "30GiB"}
                self.model = AutoModelForCausalLM.from_pretrained(
                    MODEL_PATH,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    max_memory=max_mem,
                    local_files_only=True,
                    low_cpu_mem_usage=True
                )
            else:
                # Case C: Single T4 (15GB) -> Very tight fit, might offload to CPU
                print("‚ö†Ô∏è Single Small GPU detected. Model might offload to CPU (Slow).")
                self.model = AutoModelForCausalLM.from_pretrained(
                    MODEL_PATH,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    local_files_only=True,
                    low_cpu_mem_usage=True
                )

            print("‚úÖ Model Loaded (Float16)!")
            self.is_loaded = True
            
        except Exception as e:
            print(f"‚ùå Load Failed: {e}")
            return

    def identify_topic(self, problem_text):
        """Enhanced topic classification with pattern matching"""
        problem_lower = problem_text.lower()
        
        # Multi-keyword scoring
        topic_scores = {
            "geometry": 0, "number_theory": 0, 
            "combinatorics": 0, "algebra": 0
        }
        
        # Geometry patterns
        geo_patterns = ['triangle', 'circle', 'angle', 'perimeter', 'area', 
                        'acute-angled', 'right-angled', 'congruent', 'similar',
                        'bisector', 'circumcircle', 'radius', 'diameter']
        for pattern in geo_patterns:
            if pattern in problem_lower:
                topic_scores["geometry"] += 2
        
        # Number theory patterns  
        nt_patterns = ['mod', 'prime', 'divisible', 'gcd', 'lcm', 'remainder', 
                      'modulo', 'coprime', 'factor', 'multiple', 'integer solution']
        for pattern in nt_patterns:
            if pattern in problem_lower:
                topic_scores["number_theory"] += 2
                
        # Combinatorics patterns
        comb_patterns = ['combination', 'permutation', 'probability', 'choose', 
                        'arrangement', 'tournament', 'runner', 'race', 'count']
        for pattern in comb_patterns:
            if pattern in problem_lower:
                topic_scores["combinatorics"] += 2
        
        # Algebra patterns
        algebra_patterns = ['function', 'polynomial', 'equation', 'solve for', 'root',
                           'coefficient', 'variable', 'expression']
        for pattern in algebra_patterns:
            if pattern in problem_lower:
                topic_scores["algebra"] += 2
    
        # Return highest scored topic
        best_topic = max(topic_scores, key=topic_scores.get)
        return best_topic if topic_scores[best_topic] > 0 else "general"

    def get_specialist_prompt(self, problem_text, topic):
        """Domain-Specific Prompts to Force Correct Reasoning"""
        
        base_prompt = f"Problem: {problem_text}\n\n"
        
        if topic == "geometry":
            return base_prompt + """Solve this GEOMETRY problem using coordinate geometry or sympy.
CRITICAL: Your code MUST print the final integer answer using print().
Example: print(42)

Python code:
```python
import sympy as sp
import math
# Use coordinate geometry or sympy geometric objects
"""
        
        elif topic == "number_theory":
            return base_prompt + """Solve this NUMBER THEORY problem using modular arithmetic.
CRITICAL: Your code MUST print the final integer answer using print().
Example: print(42)

Python code:
```python
import math
import sympy
# Use modular arithmetic and number theory functions
"""
        
        elif topic == "combinatorics":
            return base_prompt + """Solve this COMBINATORICS problem using itertools or counting principles.
Use efficient combinatorial algorithms.
Key techniques: combinations, permutations, inclusion-exclusion.
IMPORTANT: Use itertools for combinatorial operations.

Python code:
```python
import itertools
import math
# Use combinatorial functions and counting
"""
        
        elif topic == "algebra":
            return base_prompt + """Solve this ALGEBRA problem using symbolic mathematics.
Use sympy for equation solving and simplification.
Key techniques: equation solving, polynomial manipulation.
IMPORTANT: Use sympy for algebraic manipulations.

Python code:
```python
import sympy as sp
# Use symbolic algebra and equation solving
"""
        
        else:  # general
            return base_prompt + """Solve this mathematical problem using efficient Python code.
Use appropriate mathematical libraries and avoid brute force.
Print the final integer answer.

Python code:
```python
import math
import sympy
# Use appropriate mathematical approach
"""

    def run_python_code(self, code):
        """REAL Python Execution with Enhanced Auto-Print"""
        sys.set_int_max_str_digits(0)
        local_scope = {
            "math": math, "np": np, "numpy": np, "sympy": sympy, "itertools": itertools,
            "collections": collections, "fractions": fractions, "print": print,
            "sys": sys, "Counter": Counter, "sp": sympy
        }
        
        old_stdout = sys.stdout
        redirected_output = sys.stdout = StringIO()
        
        def timeout_handler(signum, frame):
            raise TimeoutError("Execution Timed Out")
        
        # 10-second timeout for complex problems
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(10)
        
        try:
            # Clean the code
            code = re.sub(r"```python|```", "", code).strip()
            
            # FIX: Add print statement if the last line looks like a result
            lines = [line.strip() for line in code.split('\n') if line.strip() and not line.startswith('#')]
            
            if lines and 'print(' not in code:
                last_line = lines[-1]
                
                # Case 1: Simple variable assignment (result = ...)
                if '=' in last_line and not last_line.startswith(('def ', 'class ', 'import ', 'from ')):
                    var_name = last_line.split('=')[0].strip()
                    if re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', var_name):
                        code += f"\nprint({var_name})"
                
                # Case 2: Standalone variable or function call
                elif (re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*$', last_line) or 
                      re.match(r'^[a-zA-Z_][a-zA-Z0-9_]*\(', last_line)):
                    code += f"\nprint({last_line})"
                
                # Case 3: Mathematical expression
                elif re.match(r'^[\d+\-*/(). ]+$', last_line) and len(last_line) < 50:
                    code += f"\nprint({last_line})"
            
            exec(code, local_scope)
            signal.alarm(0)
            sys.stdout = old_stdout
            return redirected_output.getvalue().strip(), "SUCCESS"
        except TimeoutError:
            signal.alarm(0)
            sys.stdout = old_stdout
            return "Error: Code took too long (>10s).", "TIMEOUT"
        except Exception as e:
            signal.alarm(0)
            sys.stdout = old_stdout
            return f"Runtime Error: {e}", "ERROR"

    def extract_answer(self, text):
        """Enhanced answer extraction with multiple patterns"""
        try:
            # Look for boxed answers \boxed{123}
            boxed_match = re.search(r'\\boxed\{(\d+)\}', text)
            if boxed_match:
                return int(boxed_match.group(1))
                
            # Look for final answer patterns
            final_patterns = [
                r'final answer[:\s]*(\d+)',
                r'answer[:\s]*(\d+)',
                r'result[:\s]*(\d+)',
                r'solution[:\s]*(\d+)'
            ]
            for pattern in final_patterns:
                match = re.search(pattern, text.lower())
                if match:
                    return int(match.group(1))
            
            # Extract all numbers and take the most plausible one
            numbers = re.findall(r'\b\d{1,5}\b', text)
            if numbers:
                # Prefer numbers that appear in final positions
                candidates = []
                for i, num in enumerate(numbers):
                    num_int = int(num)
                    if 0 <= num_int <= 99999:
                        # Weight by position (later numbers are more likely answers)
                        position_weight = i / len(numbers)
                        candidates.append((num_int, position_weight))
                
                if candidates:
                    # Return the number with highest position weight
                    return max(candidates, key=lambda x: x[1])[0]
                    
        except Exception as e:
            print(f"‚ö†Ô∏è Answer extraction error: {e}")
        
        return None

    def verify_solution(self, problem_text, generated_code, answer):
        """Basic sanity checks on generated solutions"""
        if answer is None:
            return False
            
        # Check if answer is within reasonable bounds
        if not (0 <= answer <= 99999):
            return False
            
        # Check if code actually computes something mathematical
        math_keywords = ['math.', 'sympy.', 'np.', 'import', 'def ', 'calculate', 'compute']
        if not any(keyword in generated_code for keyword in math_keywords):
            print("‚ö†Ô∏è Generated code doesn't look mathematical")
            return False
            
        return True

    def solve(self, problem_text, max_retries=1):
        if not self.is_loaded: 
            self.load()
        if not self.is_loaded: 
            return 0
        
        problem_hash = hash(problem_text)
        if problem_hash in self.problem_cache:
            return self.problem_cache[problem_hash]
        
        # STEP 1: Classify Problem Domain
        topic = self.identify_topic(problem_text)
        print(f"üß† [{topic.upper()}] Solving: {problem_text[:50]}...")
        
        start_time = time.time()
        
        # STEP 2: Get Domain-Specific Prompt
        prompt_content = self.get_specialist_prompt(problem_text, topic)
        
        # --- MANUAL CHAT TEMPLATE ---
        full_prompt = f"<|user|>\n{prompt_content}\nPlease write Python code in ```python ... ``` blocks.\n<|assistant|>\n"
        
        for attempt in range(max_retries + 1):
            gc.collect()
            torch.cuda.empty_cache()
            
            try:
                # Use the manual chat template prompt
                inputs = self.tokenizer(
                    full_prompt, 
                    return_tensors="pt", 
                    max_length=1024,
                    truncation=True,
                    padding=True
                ).to(self.model.device)
                
                print(f"    Generating code (attempt {attempt+1})...")
                generation_start = time.time()
                
                with torch.no_grad():
                    generated_ids = self.model.generate(
                        **inputs,
                        max_new_tokens=512,
                        do_sample=False,
                        pad_token_id=self.tokenizer.eos_token_id,
                        repetition_penalty=1.1
                    )
                
                gen_time = time.time() - generation_start
                print(f"    Generation took {gen_time:.1f}s")
                
                # KEY FIX: Use the FULL text to find the code block
                full_response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                
                # DEBUG: Show what the model actually generated
                new_text = full_response[len(full_prompt):]
                print(f"    Model response preview: {new_text[:200]}...")
                
                code_match = re.search(r"```python(.*?)```", full_response, re.DOTALL)
                if code_match:
                    code = code_match.group(1).strip()
                    print(f"    Found code block: {len(code)} characters")
                    if code:
                        # print(f"    Code preview: {code[:100]}...") # Optional debug
                        output, status = self.run_python_code(code)
                        
                        if status == "SUCCESS":
                            answer = self.extract_answer(output)
                            if answer is not None and self.verify_solution(problem_text, code, answer):
                                elapsed = time.time() - start_time
                                print(f"    ‚úÖ {topic.title()} success: {answer} ({elapsed:.1f}s)")
                                self.problem_cache[problem_hash] = answer
                                return answer
                            else:
                                print(f"    ‚ùå Code ran but no valid answer. Output: '{output}'")
                        else:
                            print(f"    ‚ùå {status}: {output}")
                else:
                    # Fallback: If no closing tag, maybe it just stopped?
                    if "```python" in prompt_content:
                        print(f"    ‚ö†Ô∏è No closing tag found. Trying raw text...")
                        output, status = self.run_python_code(new_text)
                        if status == "SUCCESS":
                             answer = self.extract_answer(output)
                             if answer is not None and self.verify_solution(problem_text, new_text, answer):
                                 return answer

                    print(f"    ‚ùå No valid code block found")
                
            except RuntimeError as e:
                if "out of memory" in str(e):
                    print("    ‚ö†Ô∏è OOM - clearing cache")
                    torch.cuda.empty_cache()
                    continue
                else:
                    print(f"    ‚ùå Runtime error: {e}")
            except Exception as e:
                print(f"    ‚ùå Generation error: {e}")
                continue
        
        elapsed = time.time() - start_time
        print(f"‚ùå {topic.upper()} failed after {elapsed:.1f}s")
        return 0

# ==========================================
# EXECUTION
# ==========================================
# Module-level solver instance used by the helper functions and predict()
# defined below. Constructing it does not load the model: __init__ only
# initialises attributes, and loading happens when load()/solve() is called.
solver = SpecialistAIMO3Solver()

# Add this quick test
def test_basic_generation():
    """Smoke test: have the model write a factorial snippet and execute it.

    Loads the model through the module-level ``solver``, generates up to 100
    new tokens for a fixed prompt, extracts a ```` ```python ```` block from
    the continuation and runs it with ``solver.run_python_code``. Everything
    is reported via ``print``; nothing is returned. The test is skipped when
    the model cannot be loaded.
    """
    print("üß™ Testing basic code generation...")
    solver.load()

    if not solver.is_loaded:
        print("‚ùå Skipping test because model failed to load.")
        return

    # Use manual chat template in test too
    test_prompt = "<|user|>\nWrite Python code to calculate 5 factorial and print the result.\n<|assistant|>\n"

    inputs = solver.tokenizer(test_prompt, return_tensors="pt").to(solver.model.device)
    with torch.no_grad():
        generated_ids = solver.model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=False,
            pad_token_id=solver.tokenizer.eos_token_id
        )

    # Decode only the continuation; slicing the decoded string by
    # len(test_prompt) is unreliable once special tokens are skipped.
    prompt_len = inputs["input_ids"].shape[1]
    response = solver.tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)
    print(f"Generated (Preview): {response}")

    # Check if it contains valid code
    code_match = re.search(r"```python(.*?)```", response, re.DOTALL)
    if code_match:
        code = code_match.group(1).strip()
        print(f"Extracted code: {code}")
        output, status = solver.run_python_code(code)
        print(f"Execution: {status}, Output: '{output}'")
    else:
        print("‚ö†Ô∏è Regex failed on full text. Trying loose match...")
        # Only split on the fence when it is actually present; otherwise
        # indexing the split result would raise IndexError.
        _, opening, tail = response.partition("```python")
        if opening:
            output, status = solver.run_python_code(tail)
            print(f"Execution (Fallback): {status}, Output: '{output}'")
        else:
            print("‚ùå No code block generated")

def quick_integration_test():
    """Run classification, prompt building and a short generation per sample.

    For each built-in sample problem this prints the detected topic next to
    the expected one, the prompt length, a hint, and the first 100 characters
    of a 100-token generation. Generation errors are printed, not raised.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("üöÄ TESTING COMPLETE PIPELINE")
    print(banner)

    # (problem, expected topic, hint)
    cases = (
        (
            "Calculate the number of positive integers less than 100 that are divisible by 3 or 5.",
            "number_theory",
            "Should use modular arithmetic",
        ),
        (
            "Find the area of a triangle with sides 5, 12, and 13.",
            "geometry",
            "Should use coordinate geometry",
        ),
    )

    for number, (problem, expected_topic, hint) in enumerate(cases, start=1):
        print(f"\nüß™ Test {number}: {problem[:50]}...")

        # Classification
        topic = solver.identify_topic(problem)
        print(f"    Classification: {topic} (expected: {expected_topic})")

        # Prompt construction
        prompt = solver.get_specialist_prompt(problem, topic)
        print(f"    Prompt length: {len(prompt)} chars")
        print(f"    Hint: {hint}")

        # Short generation, preview only
        try:
            encoded = solver.tokenizer(
                prompt, return_tensors="pt", max_length=1024, truncation=True
            ).to(solver.model.device)
            with torch.no_grad():
                output_ids = solver.model.generate(
                    **encoded,
                    max_new_tokens=100,
                    do_sample=False
                )
            prompt_len = encoded.input_ids.shape[1]
            preview = solver.tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
            print(f"    Generation preview: {preview[:100]}...")
        except Exception as e:
            print(f"    ‚ùå Generation failed: {e}")

def test_reference_problems(num_problems=2):
    """Run the solver on the first rows of the reference CSV and print results.

    The CSV is looked up in the competition input directory and then in the
    current directory. For each tested row the problem is read from the
    ``problem`` column (falling back to the second column) and the expected
    answer from ``answer`` (falling back to the last column).

    Args:
        num_problems: Maximum number of rows to test (default 2).

    Returns:
        0 when no reference CSV is found, otherwise None. Any other error is
        printed rather than raised.
    """
    print("\n" + "="*60)
    print("üöÄ TESTING SPECIALIST ROUTER ON REFERENCE PROBLEMS")
    print("="*60)

    try:
        possible_paths = [
            '/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv',
            'reference.csv'
        ]
        data_path = next((p for p in possible_paths if os.path.exists(p)), None)

        if not data_path:
            print("‚ö†Ô∏è Reference CSV not found.")
            return 0

        reference_df = pd.read_csv(data_path)
        print(f"Testing on {len(reference_df)} reference problems...")

        total = min(num_problems, len(reference_df))
        for i in range(total):
            row = reference_df.iloc[i]
            # Safe column access
            prob = row.get('problem', row.get(reference_df.columns[1]))
            ans = row.get('answer', row.get(reference_df.columns[-1]))

            try:
                ans = int(float(ans))
            except (TypeError, ValueError, OverflowError):
                # Non-numeric expected answer: keep the raw value for display.
                pass

            # Show the real number of problems being run, not a fixed "2".
            print(f"\nüî¢ REFERENCE PROBLEM {i+1}/{total}")
            print(f"   Expected answer: {ans}")
            print(f"   Problem preview: {prob[:80]}...")

            result = solver.solve(prob)

            status = "‚úÖ CORRECT" if result == ans else "‚ùå WRONG"
            print(f"   {status} | Got: {result}")

            # Clear memory between problems
            torch.cuda.empty_cache()
            time.sleep(1)

    except Exception as e:
        print(f"‚ùå Error testing reference problems: {e}")

def predict(id_: pl.Series, problem: pl.Series) -> pl.DataFrame:
    """Inference-server callback: answer the single problem in the batch.

    Args:
        id_: Series whose first element is the problem id.
        problem: Series whose first element is the problem text.

    Returns:
        A one-row DataFrame with columns ``id`` and ``answer``. The answer
        falls back to 0 when solving raises.
    """
    id_val = id_.item(0)
    problem_text = problem.item(0)
    try:
        answer = solver.solve(problem_text)
    except Exception as e:
        # Best-effort by design: one failing problem must not stop the run,
        # but report it instead of swallowing it silently. KeyboardInterrupt
        # and SystemExit are no longer caught here.
        print(f"predict: solving failed for id {id_val}: {e}")
        answer = 0
    return pl.DataFrame({'id': [id_val], 'answer': [answer]})

# Wrap predict() in the competition inference server. It is constructed in
# both modes but only started in the competition-rerun branch below.
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

# KAGGLE_IS_COMPETITION_RERUN set -> serve predictions; otherwise run the
# local self-tests defined above.
if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    print("üöÄ Starting Production Server...")
    inference_server.serve()
else:
    print("üî¨ Local Test Mode...")
    # Run basic test first (this also triggers model loading)
    test_basic_generation()
    
    # Continue only when the model actually loaded
    if solver.is_loaded:
        quick_integration_test()
        
        # Reference-problem check runs right after the integration test
        test_reference_problems()

🔬 Local Test Mode...
🧪 Testing basic code generation...
⏳ Loading NuminaMath-7B-TIR...
🔍 Detected 0 GPU(s)
❌ No GPU detected! This model cannot run on CPU within time limits.
❌ Skipping test because model failed to load.
