In [None]:
# AIMO 3 - Diversity Ensemble Solver
# Implements "The Architect", "The Hacker", and "The Formalist"
# to break the correlation of errors in standard ensembles.

%pip uninstall --yes 'keras' 'matplotlib' 'scikit-learn' 'tensorflow'
import warnings
warnings.simplefilter('ignore')
import ast, os, sys, subprocess, gc, re, math, time, queue, threading, contextlib
from typing import Optional
from collections import Counter, defaultdict
from concurrent.futures import as_completed, ThreadPoolExecutor
from jupyter_client import KernelManager
import pandas as pd
import polars as pl
from openai import OpenAI
from openai_harmony import (
    HarmonyEncodingName, load_harmony_encoding, SystemContent, ReasoningEffort, 
    ToolNamespaceConfig, Author, Message, Role, TextContent, Conversation
)
from transformers import set_seed
import kaggle_evaluation.aimo_3_inference_server

def set_env(input_archive, temp_dir):
    """Unpack the offline wheel archive into `temp_dir` and install packages from it.

    Extraction runs unless a marker file from a previous *successful* extraction
    is present, so a failed or interrupted `tar` is retried on the next run
    instead of leaving a half-populated directory that is silently reused.
    `pip install` (offline, from `<temp_dir>/wheels`) runs every time.

    Raises:
        subprocess.CalledProcessError: if `tar` or `pip` exits non-zero.
    """
    marker = os.path.join(temp_dir, '.extracted')
    if not os.path.exists(marker):
        os.makedirs(temp_dir, exist_ok=True)
        subprocess.run(['tar', '-xzf', input_archive, '-C', temp_dir], check=True)
        # Written only after tar succeeded (check=True raises otherwise).
        with open(marker, 'w'):
            pass
    subprocess.run([sys.executable, '-m', 'pip', 'install', '--no-index', '--find-links', f'{temp_dir}/wheels',
                    'unsloth', 'trl', 'vllm', 'openai_harmony'], check=True)

set_env('/kaggle/input/aimo-3-utils/wheels.tar.gz', '/kaggle/tmp/setup')

# Process-wide environment. Child processes inherit it: the vLLM server is
# started with subprocess.Popen without an explicit env, and each sandbox
# kernel is started with a copy of os.environ.
# NOTE(review): `transformers` is imported earlier in this cell, so the
# TRANSFORMERS_NO_* flags cannot influence that import in this process.
os.environ.update({
    'TRANSFORMERS_NO_TF': '1',
    'TRANSFORMERS_NO_FLAX': '1',
    'CUDA_VISIBLE_DEVICES': '0',
    'TOKENIZERS_PARALLELISM': 'false',
    'TRITON_PTXAS_PATH': '/usr/local/cuda/bin/ptxas',
    'TIKTOKEN_ENCODINGS_BASE': '/kaggle/tmp/setup/tiktoken_encodings',
})

In [None]:
class CFG:
    """Static configuration read as class attributes: the three role prompts and
    their schedule, model/server settings, time budgets and sampling knobs."""

    # ---- Role prompts -------------------------------------------------------

    # Architect: general step-by-step reasoning, Python used to check steps.
    prompt_architect = (
        'You are an elite mathematical problem solver. Your goal is to find the correct answer '
        'through rigorous reasoning. \n'
        '- Use Python to verify small steps.\n'
        '- Break down the problem into sub-components.\n'
        '- Output the final integer in \\boxed{}.'
    )

    # Hacker: computational search / simulation (run at a higher temperature by the solver).
    prompt_hacker = (
        'You are a Python expert specializing in brute-force search and simulation. \n'
        'Do not waste time on elegant proofs. Focus on computational attacks.\n'
        '1. Translate the problem into a checkable Python function `is_solution(x)`. \n'
        '2. Write an efficient loop or search algorithm to find the answer. \n'
        '3. If the space is infinite, check the first 1,000,000 cases or use Monte Carlo. \n'
        '4. Print the final answer found by your code inside \\boxed{}.'
    )

    # Formalist: exact symbolic work with sympy (run at a lower temperature by the solver).
    prompt_formalist = (
        'You are a theoretical mathematician. You distrust floating point arithmetic. \n'
        '1. Model every constraint using `sympy` symbols, Equations, and Solvers. \n'
        '2. Use `sympy.ntheory` for number theory and `sympy.geometry` for geometry. \n'
        '3. Avoid loops. seek closed-form analytical solutions using Python as a proof assistant. \n'
        '4. Output the exact integer answer inside \\boxed{}.'
    )

    # ---- Attempt schedule ---------------------------------------------------
    # Attempt i uses active_roles[i]; at most min(attempts, len(active_roles)) run.
    # Mix per 12 attempts: 4 hacker, 5 architect, 3 formalist.
    attempts = 12
    active_roles = (
        [prompt_hacker, prompt_architect, prompt_formalist] * 2
        + [prompt_hacker, prompt_architect, prompt_architect]
        + [prompt_hacker, prompt_architect, prompt_formalist]
    )

    # Description attached to the `python` tool namespace.
    tool_prompt = (
        'Use this tool to execute Python code. The environment is stateful. '
        'Always use print() to see output.'
    )

    # ---- Model / server -----------------------------------------------------
    served_model_name = 'gpt-oss'
    model_path = '/kaggle/input/gpt-oss-120b/transformers/default/1'
    kv_cache_dtype = 'fp8_e4m3'
    dtype = 'auto'

    # ---- Time budget (seconds) ----------------------------------------------
    notebook_limit = 17400
    base_problem_timeout = 300
    high_problem_timeout = 900

    # ---- Generation ---------------------------------------------------------
    context_tokens = 65536
    temperature = 0.9          # default; the solver overrides it per role
    workers = 16               # thread-pool width and number of sandboxes
    turns = 60                 # max generate/tool round-trips per attempt (all roles)
    batch_size = 256           # passed to vLLM as --max-num-seqs

    seed = 42
    min_p = 0.02
    gpu_memory_utilization = 0.95

    # Not read anywhere in this notebook's visible code.
    search_tokens = 32

# Seed this process's RNGs; the vLLM server receives the same value via --seed.
set_seed(seed=CFG.seed)

In [None]:
class AIMO3Template:
    """Builds the opening Harmony conversation (system + user) for one attempt."""

    def get_system_content(self, system_prompt: str, tool_config: ToolNamespaceConfig) -> SystemContent:
        """System content with `system_prompt` as model identity, HIGH reasoning
        effort, and the tools described by `tool_config`."""
        content = SystemContent.new()
        content = content.with_model_identity(system_prompt)
        content = content.with_reasoning_effort(reasoning_effort=ReasoningEffort.HIGH)
        return content.with_tools(tool_config)

    def apply_chat_template(self, system_prompt: str, user_prompt: str, tool_config: ToolNamespaceConfig) -> list[Message]:
        """Return ``[system_message, user_message]`` for ``Conversation.from_messages``."""
        return [
            Message.from_role_and_content(Role.SYSTEM, self.get_system_content(system_prompt, tool_config)),
            Message.from_role_and_content(Role.USER, user_prompt),
        ]

class AIMO3Sandbox:
    """Stateful Python sandbox backed by a dedicated Jupyter kernel.

    Each instance starts its own kernel on five freshly reserved local ports
    (allocated upward from 50000 under a class-level lock), preloads
    ``math, numpy, sympy, mpmath, itertools, collections`` with
    ``mpmath.mp.dps = 64``, and runs code under a wall-clock timeout.
    """
    _port_lock, _next_port = threading.Lock(), 50000
    # Code run at start-up and again after every reset.
    _PRELUDE = 'import math, numpy, sympy, mpmath, itertools, collections\nmpmath.mp.dps = 64\n'
    # ANSI colour escape sequences that IPython embeds in tracebacks.
    _ANSI_RE = re.compile(r'\x1b\[[0-9;]*[A-Za-z]')

    @classmethod
    def _get_next_ports(cls, count=5):
        """Reserve `count` consecutive port numbers, unique within this process."""
        with cls._port_lock:
            ports = list(range(cls._next_port, cls._next_port + count))
            cls._next_port += count
            return ports

    @classmethod
    def _strip_ansi(cls, text):
        """Remove ANSI escape sequences so the model sees plain traceback text."""
        return cls._ANSI_RE.sub('', text)

    @staticmethod
    def _format_output(stdout, stderr):
        """Combine collected stdout and stderr chunks into the text returned to the model.

        Errors are always reported, even when some stdout was produced first.
        Returns '[WARN] No output' when both are empty.
        """
        out, err = ''.join(stdout), ''.join(stderr)
        if not out and not err:
            return '[WARN] No output'
        if out and err and not out.endswith('\n'):
            out += '\n'
        return out + err

    def __init__(self, timeout, startup_timeout=60):
        """Start the kernel and run the prelude.

        Args:
            timeout: default per-execution timeout in seconds.
            startup_timeout: lower bound (seconds) for kernel readiness and the
                prelude imports. Many kernels booting at once can exceed a short
                per-execution `timeout`.

        Raises:
            Whatever `start_kernel` / `wait_for_ready` raise; the kernel is shut
            down before re-raising.
        """
        self._default_timeout, self._owns_kernel, self._client, self._km = timeout, False, None, None
        ports = self._get_next_ports(5)
        env = os.environ.copy()
        env.update({'PYDEVD_DISABLE_FILE_VALIDATION': '1', 'PYDEVD_WARN_EVALUATION_TIMEOUT': '0',
                    'JUPYTER_PLATFORM_DIRS': '1', 'PYTHONWARNINGS': 'ignore', 'MPLBACKEND': 'Agg'})
        self._km = KernelManager()
        (self._km.shell_port, self._km.iopub_port, self._km.stdin_port,
         self._km.hb_port, self._km.control_port) = ports
        self._km.start_kernel(env=env, extra_arguments=['--Application.log_level=CRITICAL'])
        # Owned as soon as it exists, so a failure below still shuts it down.
        self._owns_kernel = True
        boot_timeout = max(timeout or 0, startup_timeout)
        try:
            self._client = self._km.blocking_client()
            self._client.start_channels()
            self._client.wait_for_ready(timeout=boot_timeout)
        except Exception:
            self.close()
            raise
        out = self.execute(self._PRELUDE, timeout=boot_timeout)
        if out.startswith('[ERROR] Execution timed out'):
            print('[WARN] Sandbox prelude timed out; preloaded modules may be missing')

    def execute(self, code, timeout=None):
        """Run `code` in the kernel and return its textual output.

        Returns stdout (plus text/plain results) followed by any stderr and
        tracebacks; see `_format_output`. On timeout the kernel is interrupted
        and the partial output is returned followed by '[ERROR] Execution timed out'.
        """
        effective_timeout = timeout or self._default_timeout
        msg_id = self._client.execute(code, store_history=True, allow_stdin=False, stop_on_error=False)
        stdout, stderr, start = [], [], time.time()
        while True:
            if time.time() - start > effective_timeout:
                self._km.interrupt_kernel()
                stderr.append('[ERROR] Execution timed out')
                return self._format_output(stdout, stderr)
            try:
                msg = self._client.get_iopub_msg(timeout=1.0)
            except queue.Empty:
                continue
            # iopub is shared; ignore messages belonging to other (e.g. interrupted) requests.
            if msg.get('parent_header', {}).get('msg_id') != msg_id:
                continue
            mt, c = msg.get('msg_type'), msg.get('content', {})
            if mt == 'stream':
                (stdout if c.get('name') == 'stdout' else stderr).append(c.get('text', ''))
            elif mt == 'error':
                tb = '\n'.join(c.get('traceback') or []) or f"{c.get('ename', 'Error')}: {c.get('evalue', '')}"
                stderr.append(self._strip_ansi(tb) + '\n')
            elif mt in {'execute_result', 'display_data'}:
                if txt := c.get('data', {}).get('text/plain'):
                    stdout.append(txt)
            elif mt == 'status' and c.get('execution_state') == 'idle':
                break
        return self._format_output(stdout, stderr)

    def close(self):
        """Stop the client channels and shut the kernel down; safe to call twice."""
        if self._client is not None:
            try:
                self._client.stop_channels()
            except Exception as e:
                print(f'[WARN] stop_channels failed: {e}')
        if self._owns_kernel and self._km:
            self._owns_kernel = False
            self._km.shutdown_kernel(now=True)

    def reset(self):
        """Clear the user namespace and re-run the prelude."""
        self.execute('%reset -f\n' + self._PRELUDE)

class AIMO3Tool:
    """Harmony `python` tool that runs assistant-written code in a sandbox."""

    # Calls whose output is already shown; never wrap these in another print().
    _DISPLAY_FUNCS = frozenset({'print', 'pprint', 'display'})

    def __init__(self, timeout, prompt, sandbox):
        # `timeout` is stored for callers; execution uses the sandbox's default timeout.
        self.timeout, self.prompt, self.sandbox = timeout, prompt, sandbox

    @property
    def tool_config(self):
        """Tool namespace advertised to the model, described by `self.prompt`."""
        return ToolNamespaceConfig(name='python', description=self.prompt, tools=[])

    @classmethod
    def _auto_print(cls, code):
        """Rewrite a trailing top-level expression as print(<expr>) so its value is shown.

        Only a final statement that is a bare expression starting at column 0,
        and not already a print/pprint/display call, is rewritten. It is
        *replaced* (not re-appended), so it is evaluated exactly once. Code that
        does not parse as plain Python (e.g. IPython magics), is empty, or ends
        in any other statement is returned unchanged.
        """
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError):
            return code
        if not tree.body or not isinstance(tree.body[-1], ast.Expr):
            return code
        stmt = tree.body[-1]
        value = stmt.value
        if isinstance(value, ast.Call):
            func = value.func
            name = func.id if isinstance(func, ast.Name) else getattr(func, 'attr', None)
            if name in cls._DISPLAY_FUNCS:
                return code
        if stmt.col_offset != 0:  # e.g. `a = 1; a` shares a line with another statement
            return code
        segment = ast.get_source_segment(code, value)
        if segment is None:
            return code
        lines = code.splitlines()
        return '\n'.join(lines[:stmt.lineno - 1] + [f'print({segment})'] + lines[stmt.end_lineno:])

    def process_sync_plus(self, message):
        """Execute the code in `message` and return the output as a tool reply to the assistant."""
        code = self._auto_print(message.content[0].text)
        out = self.sandbox.execute(code)
        auth = Author(role=Role.TOOL, name='python')
        return [Message(author=auth, content=[TextContent(text=out)]).with_recipient('assistant')]

In [None]:
class DiversitySolver:
    """Role-diverse ensemble solver: many tool-using gpt-oss attempts per problem, then a weighted vote.

    Owns a local vLLM OpenAI-compatible server (a subprocess), a pool of
    sandboxes shared by the attempts, and a wall-clock budget spread over the
    remaining problems of the run.
    """

    # Vote weight per role label (see `_role_name`); the architect is the baseline.
    ROLE_WEIGHTS = {'formalist': 1.3, 'hacker': 1.1, 'architect': 1.0}
    # Per-role sampling temperature; roles not listed use cfg.temperature.
    ROLE_TEMPERATURES = {'formalist': 0.7, 'hacker': 1.0}
    # Answer extraction patterns, tried in order; the last match of the first hit wins.
    _ANSWER_PATTERNS = (r'\\boxed\s*\{\s*([0-9]+)\s*\}', r'final\s+answer\s+is\s*([0-9]+)')

    def __init__(self, cfg):
        # Start the clock first so server start-up and sandbox creation count
        # against cfg.notebook_limit when later problems are budgeted.
        self.time_start = time.time()
        self.cfg = cfg
        self.port = 8000
        self.template = AIMO3Template()
        self.encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        self.stop_ids = self.encoding.stop_tokens_for_assistant_actions()
        # Must exist before _start_server, which polls it for readiness.
        self.client = OpenAI(base_url=f'http://localhost:{self.port}/v1', api_key='sk-local', timeout=900)
        self._start_server()
        self.sandbox_timeout = getattr(cfg, 'sandbox_timeout', 6)
        self.sandbox_pool = queue.Queue()
        self._init_sandboxes()
        self.problems_remaining = getattr(cfg, 'num_problems', 50)

    def _start_server(self, startup_timeout=None):
        """Launch the vLLM server and block until it answers `models.list()`.

        Args:
            startup_timeout: seconds to wait; defaults to cfg.server_startup_timeout,
                falling back to 180.

        Raises:
            RuntimeError: if the server process exits or is not ready in time.
        """
        if startup_timeout is None:
            startup_timeout = getattr(self.cfg, 'server_startup_timeout', 180)
        # Eagle3 speculative decoding config
        spec_config = (
            '{"method":"eagle3",'
            '"model":"/kaggle/input/download-eagle3/wenliang1990/gpt-oss-120b-eagle3-aimo3",'
            '"num_speculative_tokens":3,'
            '"draft_tensor_parallel_size":1}'
        )
        cmd = [sys.executable, '-m', 'vllm.entrypoints.openai.api_server',
               '--seed', str(self.cfg.seed),
               '--model', self.cfg.model_path, '--served-model-name', self.cfg.served_model_name,
               '--tensor-parallel-size', '1', '--max-num-seqs', str(self.cfg.batch_size),
               '--gpu-memory-utilization', str(self.cfg.gpu_memory_utilization), '--port', str(self.port),
               '--dtype', self.cfg.dtype, '--kv-cache-dtype', self.cfg.kv_cache_dtype,
               '--max-model-len', str(self.cfg.context_tokens), '--disable-log-stats',
               '--enable-prefix-caching', "--speculative-config", spec_config]

        self.proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, start_new_session=True)

        print("Waiting for server...")
        start = time.time()
        while time.time() - start < startup_timeout:
            if self.proc.poll() is not None:
                raise RuntimeError(f"vLLM server exited with code {self.proc.returncode}")
            try:
                self.client.models.list()
                print(f"Server ready in {time.time()-start:.1f}s")
                return
            except Exception:
                # Not listening yet; retry.
                time.sleep(2)
        raise RuntimeError("Server timeout")

    def _init_sandboxes(self):
        """Create cfg.workers sandboxes in parallel and put them in the pool."""
        with ThreadPoolExecutor(max_workers=self.cfg.workers) as ex:
            futs = [ex.submit(AIMO3Sandbox, self.sandbox_timeout) for _ in range(self.cfg.workers)]
            for f in as_completed(futs):
                self.sandbox_pool.put(f.result())

    def _extract_boxed(self, text):
        """Return the last integer in a \\boxed{...} (else after 'final answer is'), or None."""
        if not text:
            return None
        for pattern in self._ANSWER_PATTERNS:
            m = re.findall(pattern, text, re.IGNORECASE)
            if m:
                return int(m[-1])
        return None

    def _role_name(self, role_prompt):
        """Label a role prompt by its marker phrase ('formalist', 'hacker' or 'architect')."""
        if 'theoretical' in role_prompt:
            return 'formalist'
        if 'brute-force' in role_prompt:
            return 'hacker'
        return 'architect'

    @staticmethod
    def _mean_entropy(top_logprobs):
        """Mean entropy (nats) over positions, each estimated from its top-k logprobs.

        `top_logprobs` is a sequence of {token: logprob} dicts; empty/None entries
        are skipped. Returns 0.0 when nothing is usable.
        """
        ents = []
        for dist in top_logprobs:
            if not dist:
                continue
            ent = 0.0
            for lp in dist.values():
                p = math.exp(lp)
                if p > 0:
                    ent -= p * lp
            ents.append(ent)
        return sum(ents) / len(ents) if ents else 0.0

    def _recycle_sandbox(self, sandbox):
        """Reset `sandbox` and put it (or a fresh replacement) back into the pool.

        Something is always returned to the pool so later attempts never block
        forever on `sandbox_pool.get()`. If the replacement also fails, the old
        sandbox goes back in; attempts that use it will fail and be reported.
        """
        try:
            sandbox.reset()
        except Exception as e:
            print(f"Sandbox reset failed: {e}")
            try:
                sandbox.close()
                sandbox = AIMO3Sandbox(self.sandbox_timeout)
            except Exception as e2:
                print(f"Sandbox replacement failed: {e2}")
        self.sandbox_pool.put(sandbox)

    def _attempt(self, problem, role_prompt, idx, stop_evt, deadline):
        """Run one multi-turn, tool-using attempt under `role_prompt`.

        Returns None if cancelled before starting, otherwise a dict with
        'role' (label from `_role_name`), 'answer' (int or None) and 'entropy'
        (mean top-k entropy of the last generated turn, 0.0 if unavailable).
        """
        def cancelled():
            return stop_evt.is_set() or time.time() > deadline

        if cancelled():
            return None

        role = self._role_name(role_prompt)
        local_temp = self.ROLE_TEMPERATURES.get(role, self.cfg.temperature)

        sandbox = self.sandbox_pool.get()
        tool = AIMO3Tool(self.sandbox_timeout, self.cfg.tool_prompt, sandbox)
        ans, entropy = None, 0.0

        try:
            msgs = self.template.apply_chat_template(role_prompt, problem, tool.tool_config)
            conv = Conversation.from_messages(msgs)

            for _ in range(self.cfg.turns):
                if cancelled():
                    break

                prompt_ids = self.encoding.render_conversation_for_completion(conv, Role.ASSISTANT)
                max_tokens = self.cfg.context_tokens - len(prompt_ids)
                if max_tokens < 512:
                    break

                tok_buf, txt_buf, logprobs_buf = [], [], []
                interrupted = False
                # closing() ends the HTTP stream (and hence the server-side request)
                # when we stop reading early, instead of letting it run to max_tokens.
                with contextlib.closing(self.client.completions.create(
                    model=self.cfg.served_model_name, prompt=prompt_ids, temperature=local_temp,
                    max_tokens=max_tokens, stream=True, logprobs=5,
                    # vLLM-specific parameters are not named arguments of the OpenAI
                    # client; they must be sent through extra_body.
                    extra_body={'min_p': self.cfg.min_p, 'stop_token_ids': self.stop_ids,
                                'return_token_ids': True},
                )) as stream:
                    for chunk in stream:
                        if cancelled():
                            interrupted = True
                            break
                        if not chunk.choices:
                            continue
                        choice = chunk.choices[0]
                        text = choice.text or ''
                        token_ids = getattr(choice, 'token_ids', None)
                        if token_ids:
                            tok_buf.extend(token_ids)
                            txt_buf.append(text)
                            lp = getattr(choice, 'logprobs', None)
                            if lp and lp.top_logprobs:
                                logprobs_buf.extend(lp.top_logprobs)
                        if '}' in text:
                            # Quick scan of the recent text for a completed boxed answer.
                            found = self._extract_boxed(''.join(txt_buf[-40:]))
                            if found is not None:
                                ans = found
                                break

                # Score this turn, including one that ended early on a boxed answer.
                if logprobs_buf:
                    entropy = self._mean_entropy(logprobs_buf)
                # Do not parse a completion that was cut off by cancellation.
                if ans is not None or interrupted or not tok_buf:
                    break

                new_msgs = self.encoding.parse_messages_from_completion_tokens(tok_buf, Role.ASSISTANT)
                if not new_msgs:
                    break
                conv.messages.extend(new_msgs)

                last = new_msgs[-1]
                if last.channel == 'final':
                    if last.content:
                        ans = self._extract_boxed(last.content[0].text)
                    break

                if last.recipient == 'python':
                    conv.messages.extend(tool.process_sync_plus(last))

        except Exception as e:
            print(f"Err {idx}: {e}")
        finally:
            self._recycle_sandbox(sandbox)

        return {'role': role, 'answer': ans, 'entropy': entropy}

    def _tally_votes(self, results):
        """Sum role-weighted, entropy-discounted votes per answer (None answers ignored)."""
        votes = defaultdict(float)
        for r in results:
            if r['answer'] is None:
                continue
            weight = self.ROLE_WEIGHTS.get(r['role'], 1.0)
            weight /= 1.0 + r['entropy']  # less confident trajectories count less
            votes[r['answer']] += weight
        return votes

    def solve(self, problem):
        """Run the scheduled attempts for `problem` within its time budget and return the voted answer (0 if none)."""
        print(f"Problem: {problem[:50]}...")

        # Budget: at least base_problem_timeout, at most high_problem_timeout, and
        # no more than leaves base_problem_timeout for every remaining problem.
        elapsed = time.time() - self.time_start
        budget = max(self.cfg.base_problem_timeout,
                     min(self.cfg.notebook_limit - elapsed - (self.problems_remaining-1)*self.cfg.base_problem_timeout,
                         self.cfg.high_problem_timeout))
        deadline = time.time() + budget

        results, valid_answers = [], []
        stop_evt = threading.Event()
        early_stop_votes = getattr(self.cfg, 'early_stop_votes', 4)

        with ThreadPoolExecutor(max_workers=self.cfg.workers) as ex:
            num_attempts = min(self.cfg.attempts, len(self.cfg.active_roles))
            futures = [ex.submit(self._attempt, problem, self.cfg.active_roles[i], i, stop_evt, deadline)
                       for i in range(num_attempts)]

            for f in as_completed(futures):
                try:
                    r = f.result()
                except Exception as e:
                    print(f"Future fail: {e}")
                    continue
                if not r:
                    continue
                results.append(r)
                if r['answer'] is not None:
                    valid_answers.append(r['answer'])
                    # Stop the remaining attempts once one answer has enough raw votes.
                    if Counter(valid_answers).most_common(1)[0][1] >= early_stop_votes:
                        stop_evt.set()

        self.problems_remaining = max(0, self.problems_remaining - 1)

        votes = self._tally_votes(results)
        if votes:
            final = max(votes.items(), key=lambda x: x[1])[0]
            print(f"Final Answer: {final} (Votes: {len(valid_answers)})")
            return final

        print("Final Answer: 0 (No valid answers)")
        return 0

In [None]:
solver = DiversitySolver(CFG)


def predict(id_: pl.DataFrame, question: pl.DataFrame, answer: Optional[pl.DataFrame] = None) -> pl.DataFrame:
    """Kaggle competition entry point: solve one problem and return its (id, answer) row.

    Any exception raised by the solver is logged and answered with 0, so a
    single failing problem does not abort the whole submission.
    """
    id_val = id_.item(0)
    q_text = question.item(0)

    gc.collect()
    try:
        final_ans = solver.solve(q_text)
    except Exception as e:
        print(f"Solver failed on {id_val}: {e}")
        final_ans = 0
    finally:
        gc.collect()

    return pl.DataFrame({'id': id_val, 'answer': final_ans})

# Launch Inference Server
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

if os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    inference_server.serve()
else:
    inference_server.run_local_gateway(
        ('/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv',)
    )