In [1]:
# Remove the preinstalled keras / matplotlib / scikit-learn / tensorflow packages from the
# kernel's environment before any later cell imports them.
# NOTE(review): presumably done to avoid conflicts with the offline wheels installed in the
# next cell — confirm the motivation.
%pip uninstall --yes 'keras' 'matplotlib' 'scikit-learn' 'tensorflow'

Found existing installation: keras 3.10.0
Uninstalling keras-3.10.0:
  Successfully uninstalled keras-3.10.0
Found existing installation: matplotlib 3.10.0
Uninstalling matplotlib-3.10.0:
  Successfully uninstalled matplotlib-3.10.0
Found existing installation: scikit-learn 1.6.1
Uninstalling scikit-learn-1.6.1:
  Successfully uninstalled scikit-learn-1.6.1
Found existing installation: tensorflow 2.19.0
Uninstalling tensorflow-2.19.0:
  Successfully uninstalled tensorflow-2.19.0
Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import sys
import subprocess
import warnings
warnings.simplefilter('ignore')

def set_env(input_archive, temp_dir):
    """Unpack the offline wheel archive (once) and pip-install the inference stack from it.

    Args:
        input_archive: Path to the ``.tar.gz`` archive. Its contents are extracted into
            ``temp_dir``; the wheels are then looked up under ``<temp_dir>/wheels``.
        temp_dir: Extraction directory. If it already exists, extraction is skipped and
            only the ``pip install`` step runs.

    Raises:
        subprocess.CalledProcessError: If ``tar`` or ``pip install`` exits non-zero.
    """
    import shutil  # local import: this cell runs before the notebook's main import cell

    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir, exist_ok=True)
        try:
            subprocess.run(['tar', '-xzf', input_archive, '-C', temp_dir], check=True)
        except BaseException:
            # A failed or interrupted extraction must not leave the directory behind:
            # the existence check above would then skip extraction on every retry and
            # the (silenced) pip step below would fail without any usable diagnostics.
            shutil.rmtree(temp_dir, ignore_errors=True)
            raise

    # Offline install: --no-index forbids PyPI access, --find-links points at the
    # unpacked wheels. Output is discarded; check=True still surfaces a failure.
    subprocess.run([
        sys.executable, 
        '-m', 
        'pip', 
        'install', 
        '--no-index', 
        '--find-links', 
        f'{temp_dir}/wheels', 
        'unsloth', 
        'trl', 
        'vllm', 
        'openai_harmony'
    ], 
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
    check=True)

# Unpack the bundled archive (first run only) and install the offline inference stack.
set_env(
    input_archive='/kaggle/input/aimo-3-utils/wheels.tar.gz', 
    temp_dir='/kaggle/tmp/setup'
)

# Sanity check: list the tiktoken encoding files under the extraction directory
# (the same directory is exported as TIKTOKEN_ENCODINGS_BASE in the next cell).
subprocess.run(['ls', '/kaggle/tmp/setup/tiktoken_encodings'])

cl100k_base.tiktoken
o200k_base.tiktoken


CompletedProcess(args=['ls', '/kaggle/tmp/setup/tiktoken_encodings'], returncode=0)

In [3]:
# Process-wide environment switches. They are set before the imports below so that the
# libraries imported afterwards (and the subprocesses started later, which inherit the
# environment) see them.
# NOTE(review): presumably these keep transformers from loading TensorFlow / Flax — confirm.
os.environ['TRANSFORMERS_NO_TF'] = '1'
os.environ['TRANSFORMERS_NO_FLAX'] = '1'
# Expose only GPU 0 to CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
os.environ['TRITON_PTXAS_PATH'] = '/usr/local/cuda/bin/ptxas'
# Directory listed at the end of the previous cell (cl100k_base / o200k_base files).
os.environ['TIKTOKEN_ENCODINGS_BASE'] = '/kaggle/tmp/setup/tiktoken_encodings'

import gc
import re
import math
import time
import queue
import threading
import contextlib
from collections import deque
from typing import Iterable, Optional
from jupyter_client import KernelManager
from collections import Counter, defaultdict
from concurrent.futures import as_completed, ThreadPoolExecutor

import pandas as pd
import polars as pl
import shutil

from openai import OpenAI

from openai_harmony import (
    HarmonyEncodingName, 
    load_harmony_encoding, 
    SystemContent, 
    ReasoningEffort, 
    ToolNamespaceConfig, 
    Author, 
    Message, 
    Role, 
    TextContent, 
    Conversation
)

from transformers import set_seed
import kaggle_evaluation.aimo_3_inference_server

class CFG:
    """Static configuration for the AIMO-3 inference notebook.

    Used as a plain namespace: the rest of the notebook reads the values as class
    attributes (e.g. ``CFG.seed``) or receives the class as a ``cfg`` argument.

    The prompt texts are built from adjacent string literals, which Python concatenates
    with no separator. Every fragment therefore has to carry its own trailing space or
    newline; several fragments were missing one, which fused sentences and bullets
    together in the text actually sent to the model.
    """

    # ---- solver prompts ----
    system_prompt = (
        'You are a world-class International Mathematical Olympiad (IMO) competitor. '
        'The final answer must be a non-negative integer between 0 and 99999. '
        'You must place the final integer answer inside \\boxed{}. '
        'Use \\boxed{} exactly once at the very end (never for intermediate results). '
        'If you cannot finish within time, output your best verified result anyway as \\boxed{N}.'
    )
    
    tool_prompt = (
        'Use this tool to execute Python code. '
        'The environment is a stateful Jupyter notebook. '
        'You must use print() to output results. '

        # --- preview / slicing safety ---
        'Preview helper available: use head(x,k) to safely preview dict/list/df. '
        'Never write something[:k] unless you are sure it is a list/tuple; '
        'for dict always use list(d.items())[:k] or head(d,k). '

        # --- huge integer safety ---
        'Do NOT do str(x), len(str(x)), or f"{x}" on huge integers. '
        'Prefer x.bit_length(), x % mod, gcd, or modular arithmetic. '
        'For large exponents, ALWAYS use pow(a, e, mod); NEVER call pow(a, huge_e) without mod. '
        'If you need to print a huge integer, use p(x) or pint(x, mod=...) for safe summaries. '

        # --- number theory / correctness ---
        'When divisors or prime factors matter, use sympy.factorint(n). '
        'Use explicit namespaces for number theory helpers '
        '(e.g. math.gcd / math.lcm or sympy.gcd); do NOT use bare gcd/lcm names. '

        # --- performance & safety rules ---
        'Never use while True; all loops must have explicit bounds. '
        'Complexity budget: keep each Python call fast (<~2 seconds). '
        'Start with small bounds and scale up only if needed. '
        'Avoid large nested loops or wide brute-force scans; '
        'if a scan is slow, reduce bounds or switch methods '
        '(e.g. modular arithmetic, factorization, sieving, caching). '

        # --- timeout avoidance ---
        'Batch work: do NOT call python repeatedly for small steps; '
        'write one cell that computes all needed values. '
        'Before any scan or loop, start with a tiny bound (<=200 or <=2000), '
        'time it, then expand gradually (x2/x3) only if fast. '
        'If execution times out, do NOT rerun the same code; '
        'shrink bounds, add caching, or change the algorithm. '
        'If computing many candidates or ratios, use caching, early-break, '
        'and avoid list comprehensions that call expensive functions.'
    )

    preference_prompt = (
        'You have access to `math`, `numpy` and `sympy` to solve the problem. '
        'Prefer verifiable approaches: reduce to modular arithmetic / factorization / small candidate sets. '
        'If an argument depends on choosing the best among many integers, define a candidate set and a coverage strategy (prove a bound or do a bounded scan + verification).'
    )

    # ---- planner (separate session) ----
    planner_system_prompt = (
        'You are an expert IMO problem-solving PLANNER. '
        'Your job is to produce a short plan to guide another solver. '
        'Strict rule: do NOT state any computed values for g(c), p, q, or the remainder. '
        'Do NOT solve the problem. Do NOT output any final answer or \\boxed{}. '
        "Do NOT include any concluding sentence of the form 'therefore the answer is ...' or 'so remainder is ...'. "
        'You may mention problem constants (e.g., 2025, 2025!, M) only as symbols. '
        'Do NOT write Python code. Output must be concise and actionable.'
    )

    # Every rule bullet ends with a bare '\n' so that each rule starts at column 0.
    planner_prompt = (
        'Output EXACTLY this template, no extra text:\n'
        'PLAN:\n'
        '- ...\n'
        '- ...\n'
        '- ...\n'
        '- ...\n'
        '- ...\n'
        'DIGEST:\n'
        '- ...\n'
        '<<<END>>>\n'
        'Rules:\n'
        '- PLAN has 5-8 bullets, each <=120 chars.\n'
        '- Each PLAN bullet must be an action (verb-first), not a conclusion.\n'
        '- DIGEST is exactly 1 bullet, <=256 chars.\n'
        '- The line "<<<END>>>" must be on its own line (no other text on that line).\n'
        '- "PLAN:" and "DIGEST:" must NOT be prefixed by "-" or "*". They must be standalone headers.\n'
        '- Plan only: do NOT compute the final answer, do NOT include any final numeric result, do NOT write \\boxed{}.\n'
        '- No explanations, no meta talk, no code.\n'
        '- Do NOT include the words: analysis, final, assistant, user, rewrite, template.\n'
        '- Must include at least ONE new verification/search tactic not used in the immediately previous attempt '
        '(e.g., widen a bound, change enumeration order, add an independent check, prove a missing lemma).\n'
        '- Must explicitly list the most recent failure mode (from history) AND one concrete prevention rule '
        '(imperative form, e.g., "Do not print huge ints; use bit_length/mod", "Do not assume X without a small-case check").\n'
        '- DIGEST must NOT restate PLAN bullets; it should compress the approach into ONE actionable sentence.\n'
        '- DIGEST must NOT contain the marker "<<<END>>>". The marker line is standalone.\n'
    )

    # ---- planner switches and limits ----
    plan_enabled = False
    digest_enabled = False
    plan_sanitize = True
    plan_context_limit = 1536
    plan_digest_limit = 256
    plan_max_tokens = 384
    plan_temperature = 0.2

    # ---- optional vLLM torch.compile cache warm-up (copied from src to dst when enabled) ----
    warmup_compile_cache = False
    compile_cache_src = "/kaggle/input/gpt-oss-120b-cache-compile/torch_compile_cache"
    compile_cache_dst = "/root/.cache/vllm/torch_compile_cache"

    # ---- vLLM server / model ----
    vllm_port = 8000
    served_model_name = 'gpt-oss'
    model_path = '/kaggle/input/gpt-oss-120b/transformers/default/1'

    # not working!! P5 => 62140 (wrong answer)
    # served_model_name = 'gpt-oss-sft'
    # model_path = '/kaggle/input/gpt-oss-sft-aimo3/transformers/default/1'
    
    kv_cache_dtype = 'fp8_e4m3'
    dtype = 'auto'

    # ---- time budgets (NOTE(review): presumably seconds — confirm against the solver loop) ----
    high_problem_timeout = 900
    base_problem_timeout = 300

    notebook_limit = 17400
    server_timeout = 180

    session_timeout = 960
    jupyter_timeout = 7    # 6
    sandbox_timeout = 3

    # ---- generation / search budgets ----
    stream_interval = 200
    context_tokens = 65536
    buffer_tokens = 512
    search_tokens = 256
    top_logprobs = 5
    batch_size = 256
    early_stop = 4
    attempts = 8
    workers = 16
    turns = 128
    seed = 42

    # ---- sampling ----
    gpu_memory_utilization = 0.96
    temperature = 1.0
    min_p = 0.02

    # ---- answer selection ----
    select_policy = 'vote'  # 'entropy'
    penalty_err_enabled = False
    penalty_err_alpha = 0.1

    # --- vote bonus for "verification signal" ---
    # none | calls | ok_calls
    bonus_py_call = "none"
    bonus_py_alpha = 0.05           # multiplicative bonus coefficient per (valid) py call
    bonus_py_cap = 4                # max number of calls counted (keeps the bonus from being inflated)

    # ---- debugging / logging ----
    debug = True
    debug_req = True
    debug_resp = True
    debug_limit = 3000
    debug_cols = ['Log', 'Plan', 'PlanRaw', 'PlanSanitized', 'PlanDigest']

def _delete(name: str):
    if name is not None and name != "" and name in globals(): del globals()[name]

def _fmt_time(seconds: float) -> str:
    s = int(round(max(0.0, seconds)))
    m, s = divmod(s, 60)
    return f"{m}:{s:02d}"

# Seed the random number generators once, using the configured seed, via transformers.set_seed.
set_seed(CFG.seed)

In [4]:
class AIMO3Sandbox:
    """A dedicated Jupyter kernel used as a stateful Python sandbox.

    Each instance starts its own kernel on five consecutive ports taken from a class-wide
    counter, connects a blocking client to it and pre-loads a set of helper imports and
    safe-printing helpers (``head``, ``pint``, ``p``) into the kernel namespace.

    The instance owns the kernel: ``close()`` (also invoked from ``__del__``) stops the
    client channels and shuts the kernel down. If construction fails after the kernel
    was launched, the kernel is shut down before the exception propagates.
    """

    # Class-wide port allocator shared by all sandboxes in this process.
    _port_lock = threading.Lock()
    _next_port = 50000

    @classmethod
    def _get_next_ports(cls, count: int = 5) -> list[int]:
        """Reserve and return ``count`` consecutive port numbers (thread-safe counter).

        The ports are only numbered here; nothing checks that they are actually free.
        """
        with cls._port_lock:
            ports = list(range(cls._next_port, cls._next_port + count))
            cls._next_port += count
            return ports

    def __init__(self, timeout: float):
        """Start the kernel, connect to it and install the helper code.

        Args:
            timeout: Default per-execution timeout in seconds; also used as the limit
                when waiting for the kernel to become ready.

        Raises:
            Exception: Whatever the kernel manager / client raises during start-up.
                The kernel (if launched) is shut down before re-raising.
        """
        self._default_timeout = timeout
        self._owns_kernel = False
        self._client = None
        self._km = None
        
        ports = self._get_next_ports(5)

        env = os.environ.copy()
        env['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
        env['PYDEVD_WARN_EVALUATION_TIMEOUT'] = '0'
        env['JUPYTER_PLATFORM_DIRS'] = '1'
        env['PYTHONWARNINGS'] = 'ignore'
        env['MPLBACKEND'] = 'Agg'

        # Source executed inside the kernel at start-up and again after every reset().
        self._helper_imports = """
import math
import numpy
import sympy
import itertools
import collections
import mpmath
import sys
import re

mpmath.mp.dps = 64

# ---- mitigate Python 3.11 huge-int str digit limit if available ----
try:
    if hasattr(sys, "set_int_max_str_digits"):
        sys.set_int_max_str_digits(0)  # unlimited (best-effort)
except Exception:
    pass
"""

        self._helper_methods = """
# ---- safe formatting helpers ----
def _safe_int_str(n):
    try:
        return str(n)
    except Exception:
        try:
            return f"<int bit_length={int(n.bit_length())}>"
        except Exception:
            return "<int>"

def _safe_atom(x):
    if isinstance(x, int):
        # Avoid huge-int str; represent very large ints compactly
        bl = x.bit_length()
        if bl >= 4096:
            return f"<int bit_length={bl}>"
        return x
    return x

def head(x, k=10):
    \"\"\"Safe preview: works for dict/list/tuple/set/pandas/polars objects.\"\"\"
    try:
        if isinstance(x, dict):
            items = list(x.items())[:k]
            return [(_safe_atom(a), _safe_atom(b)) for a, b in items]
        if isinstance(x, (list, tuple)):
            return [_safe_atom(v) for v in x[:k]]
        if isinstance(x, set):
            return [_safe_atom(v) for v in list(x)[:k]]
        # pandas / polars DataFrame-like
        if hasattr(x, "head"):
            return x.head(k)
    except Exception as _e:
        return f"<head error: {_e}>"
    # fallback: string (safe)
    try:
        s = repr(x)
        return s[:2000] + ("..." if len(s) > 2000 else "")
    except Exception:
        return f"<{type(x).__name__}>"

def pint(x, mod=None, name=None):
    \"\"\"Safe summary for ints: avoids huge-int str.\"\"\"
    if not isinstance(x, int):
        print(f"{name + ': ' if name else ''}{type(x).__name__}")
        return
    bl = x.bit_length()
    if mod is None:
        msg = f"int(bit_length={bl})"
    else:
        try:
            msg = f"int(bit_length={bl}, mod={mod} => {x % mod})"
        except Exception:
            msg = f"int(bit_length={bl}, mod={mod} => <error>)"
    print(f"{name + ': ' if name else ''}{msg}")

def p(x, name=None, k=10):
    \"\"\"Safe print: dict/list previews via head(); ints via pint(); otherwise repr-trunc.\"\"\"
    if isinstance(x, int):
        pint(x, name=name)
        return
    if isinstance(x, dict):
        try:
            h = head(x, k)
            print(f"{name + ': ' if name else ''}dict(len={len(x)}), head={h}")
        except Exception as _e:
            print(f"{name + ': ' if name else ''}dict(<error: {_e}>)")
        return
    if isinstance(x, (list, tuple, set)):
        try:
            h = head(x, k)
            print(f"{name + ': ' if name else ''}{type(x).__name__}(len={len(x)}), head={h}")
        except Exception as _e:
            print(f"{name + ': ' if name else ''}{type(x).__name__}(<error: {_e}>)")
        return
    try:
        s = repr(x)
        if len(s) > 2000:
            s = s[:2000] + "..."
        print(f"{name + ': ' if name else ''}{s}")
    except Exception:
        print(f"{name + ': ' if name else ''}<{type(x).__name__}>")
"""

        try:
            self._km = KernelManager()
            self._km.shell_port = ports[0]
            self._km.iopub_port = ports[1]
            self._km.stdin_port = ports[2]
            self._km.hb_port = ports[3]
            self._km.control_port = ports[4]

            # Take ownership *before* launching: if anything below fails (for example
            # the readiness wait times out) close() must still shut the kernel down,
            # otherwise the kernel process is leaked.
            self._owns_kernel = True
            self._km.start_kernel(env=env, extra_arguments=['--Application.log_level=CRITICAL'])

            self._client = self._km.blocking_client()
            self._client.start_channels()
            self._client.wait_for_ready(timeout=self._default_timeout)

            # The returned text is ignored; a time-out here only yields an error string.
            self.execute([self._helper_imports, self._helper_methods])
        except BaseException:
            self.close()
            raise

    def _format_error(self, traceback: list[str]) -> str:
        """Flatten a kernel traceback into plain text.

        ANSI colour codes are stripped, and frames that point into files other than the
        executed cell (no ``ipython-input`` in the ``File "..."`` line) are dropped.
        """
        clean_lines = []
        for frame in traceback:
            clean_frame = re.sub(r'\x1b\[[0-9;]*m', '', frame)
            if 'File "' in clean_frame and 'ipython-input' not in clean_frame:
                continue
            clean_lines.append(clean_frame)
        return ''.join(clean_lines)

    def execute(self, code: str | Iterable[str], timeout: float | None = None) -> str:
        """Run ``code`` in the kernel and return its textual output.

        Args:
            code: Source text, or a list/tuple of snippets; non-empty string snippets are
                joined with blank lines and run as a single cell.
            timeout: Seconds to wait; a falsy value (``None`` or ``0``) selects the
                default passed to the constructor.

        Returns:
            Collected stdout (plus the plain-text form of displayed results), followed by
            stderr / the cleaned traceback if there was any. Returns an ``[ERROR] ...``
            string after interrupting the kernel on timeout, and a ``[WARN] ...`` string
            when the cell produced no output at all.
        """
        client = self._client
        effective_timeout = timeout or self._default_timeout

        if isinstance(code, (list, tuple)):
            code = "\n\n".join(
                c.rstrip() for c in code if isinstance(c, str) and c.strip()
            )

        msg_id = client.execute(
            code, 
            store_history=True, 
            allow_stdin=False, 
            stop_on_error=False
        )

        stdout_parts = []
        stderr_parts = []
        
        start_time = time.time()

        while True:
            elapsed = time.time() - start_time

            if elapsed > effective_timeout:
                # Interrupt and give up immediately. Messages still produced by the
                # interrupted cell are skipped by later calls via the msg_id filter.
                self._km.interrupt_kernel()

                return f'[ERROR] Execution timed out after {effective_timeout} seconds'

            try:
                # Poll in 1 s slices so the timeout check above runs regularly.
                msg = client.get_iopub_msg(timeout=1.0)

            except queue.Empty:
                continue

            # Ignore output that belongs to some other (earlier) request.
            if msg.get('parent_header', {}).get('msg_id') != msg_id:
                continue

            msg_type = msg.get('msg_type')
            content = msg.get('content', {})

            if msg_type == 'stream':
                text = content.get('text', '')

                if content.get('name') == 'stdout':
                    stdout_parts.append(text)

                else:
                    stderr_parts.append(text)

            elif msg_type == 'error':
                traceback_list = content.get('traceback', [])

                stderr_parts.append(self._format_error(traceback_list))

            elif msg_type in {'execute_result', 'display_data'}:
                data = content.get('data', {})
                text = data.get('text/plain')

                if text:
                    stdout_parts.append(text if text.endswith('\n') else f'{text}\n')

            elif msg_type == 'status':
                # The kernel reports 'idle' once this request has finished.
                if content.get('execution_state') == 'idle':
                    break

        stdout = ''.join(stdout_parts)
        stderr = ''.join(stderr_parts)

        if stderr:
            return f'{stdout.rstrip()}\n{stderr}' if stdout else stderr

        return stdout if stdout.strip() else '[WARN] No output. Use print() to see results.'

    def close(self):
        """Stop the client channels and shut down the owned kernel (best-effort).

        Safe to call repeatedly and on partially constructed instances; every step
        suppresses its own errors.
        """
        client = getattr(self, '_client', None)
        km = getattr(self, '_km', None)
        owns_kernel = getattr(self, '_owns_kernel', False)

        # Clear the state first so that a second call (e.g. from __del__) is a no-op.
        self._client = None
        self._owns_kernel = False

        with contextlib.suppress(Exception):
            if client:
                client.stop_channels()

        if owns_kernel and km is not None:
            with contextlib.suppress(Exception):
                km.shutdown_kernel(now=True)

            with contextlib.suppress(Exception):
                km.cleanup_resources()

    def reset(self):
        """Clear the kernel namespace (``%reset -f``) and reinstall the helper code."""
        self.execute(['%reset -f\n', self._helper_imports, self._helper_methods])

    def __del__(self):
        self.close()

In [5]:
class AIMO3Tool:
    """Harmony ``python`` tool backed by an ``AIMO3Sandbox`` kernel.

    The sandbox is either supplied by the caller or created lazily on first use.
    Executions on one tool instance are serialized with a lock.
    """

    def __init__(self, local_jupyter_timeout: float, tool_prompt: str, sandbox=None):
        """Store the configuration; no kernel is started here.

        Args:
            local_jupyter_timeout: Timeout handed to a lazily created sandbox.
            tool_prompt: Description text exposed through ``instruction`` / ``tool_config``.
            sandbox: Optional existing sandbox to run code in.
        """
        self._local_jupyter_timeout = local_jupyter_timeout
        self._tool_prompt = tool_prompt
        self._jupyter_session = sandbox
        
        # Records whether this tool creates its own sandbox; not read inside this class.
        self._owns_session = sandbox is None
        
        self._execution_lock = threading.Lock()
        self._init_lock = threading.Lock()

    def _ensure_session(self):
        """Create the sandbox on first use (checked again under the init lock)."""

        if self._jupyter_session is None:
            with self._init_lock:
                if self._jupyter_session is None:
                    self._jupyter_session = AIMO3Sandbox(timeout=self._local_jupyter_timeout)

    def _ensure_last_print(self, code: str) -> str:
        """Wrap a trailing bare expression in ``print(...)`` so its value is shown.

        Only a final *top-level expression statement* is wrapped. The previous version
        wrapped whatever text was on the last line, which broke valid code: ``x = 5``
        became ``print(x = 5)``, an indented last line lost its indentation, and the
        closing ``)`` of a multi-line call became ``print())``.

        The code is returned unchanged when the last line is empty, a comment, or
        already mentions ``print`` / ``import`` (the historical shortcuts), when the
        last statement is not an expression, when it is already a ``print(...)`` call,
        or when wrapping would not yield parseable code.

        Args:
            code: Source text of the cell.

        Returns:
            ``code`` itself when nothing is wrapped, otherwise the stripped source with
            the final expression wrapped in ``print(...)``.
        """
        import ast  # local import: not part of the notebook's import cell

        def _parses(source: str) -> bool:
            """True if ``source`` is syntactically valid plain Python."""
            try:
                ast.parse(source)
            except (SyntaxError, ValueError, RecursionError):
                return False
            return True

        stripped = code.strip()
        lines = stripped.split('\n')
        last_line = lines[-1].strip()

        if not last_line or last_line.startswith('#'):
            return code

        if 'print' in last_line or 'import' in last_line:
            return code

        try:
            tree = ast.parse(stripped)
        except (SyntaxError, ValueError, RecursionError):
            tree = None

        if tree is None:
            # Not plain Python as a whole (IPython magics, shell escapes, or a genuine
            # syntax error). Keep the legacy single-line wrap, but only when the last
            # line is unindented and the wrapped line is itself valid Python.
            candidate_line = 'print(' + last_line + ')'
            if lines[-1][:1].isspace() or not _parses(candidate_line):
                return code
            lines[-1] = candidate_line
            return '\n'.join(lines)

        if not tree.body:
            return code

        last_stmt = tree.body[-1]
        if not isinstance(last_stmt, ast.Expr):
            # Assignments, loops, defs, returns, ... must never be wrapped.
            return code

        value = last_stmt.value
        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == 'print'
        ):
            # A multi-line print(...) whose last line is just ')'.
            return code

        segment = ast.get_source_segment(stripped, last_stmt)
        if not segment or not stripped.endswith(segment):
            # Trailing comment / semicolon after the expression: leave it alone.
            return code

        wrapped = stripped[:len(stripped) - len(segment)] + 'print(' + segment + ')'
        return wrapped if _parses(wrapped) else code

    @property
    def instruction(self) -> str:
        """The tool description text given at construction."""
        return self._tool_prompt

    @property
    def tool_config(self) -> ToolNamespaceConfig:
        """Harmony tool namespace named ``python`` carrying the instruction text."""
        return ToolNamespaceConfig(
            name='python', 
            description=self.instruction, 
            tools=[]
        )

    def _make_response(self, output: str, channel: str | None = None) -> Message:
        """Build the tool-role reply message addressed to the assistant.

        Args:
            output: Text to return to the model.
            channel: Optional channel to set on the message.
        """
        content = TextContent(text=output)
        author = Author(role=Role.TOOL, name='python')
        message = Message(author=author, content=[content]).with_recipient('assistant')

        if channel:
            message = message.with_channel(channel)

        return message

    def process_sync_plus(self, message: Message) -> list[Message]:
        """Execute the code carried by ``message`` and return the reply message(s).

        The script is read from ``message.content[0].text``, passed through
        ``_ensure_last_print`` and run in the sandbox under the execution lock. A
        ``TimeoutError`` raised by the sandbox is converted into an ``[ERROR]`` reply.
        """
        self._ensure_session()
        raw_script = message.content[0].text
        final_script = self._ensure_last_print(raw_script)

        with self._execution_lock:
            try:
                output = self._jupyter_session.execute(final_script)

            except TimeoutError as exc:
                output = f'[ERROR] {exc}'

        return [self._make_response(output, channel=message.channel)]

In [6]:
class AIMO3Logger:
    """Formats debug output and writes one Markdown log file per problem."""

    def __init__(self, cfg):
        """Keep a reference to the configuration (reads ``debug``, ``debug_limit``, ...)."""
        self.cfg = cfg

    def get_debug_snippet(self, text: str) -> str:
        """Shorten ``text`` to about ``cfg.debug_limit`` characters.

        Texts within the limit are returned unchanged (``None``/empty gives ``""``).
        Longer texts keep the first 100 characters and the last ``limit - 100``
        characters, joined by a ``...`` marker line. For limits of 100 or less only the
        first ``limit`` characters are kept; previously the computed tail length became
        zero or negative and the slice returned (almost) the whole text again.
        """
        limit = self.cfg.debug_limit
        if not text or len(text) <= limit:
            return text or ""
        head_len = min(100, max(0, limit))
        tail_len = max(0, limit - head_len)
        head = text[:head_len]
        # text[-0:] would be the whole string, so the empty tail is handled explicitly.
        tail = text[-tail_len:] if tail_len else ""
        return f"{head}\n ... \n{tail}"

    def format_markdown(self, text: str, mode: str = "quote") -> str:
        """Render ``text`` as a Markdown block.

        Lines starting with ``#`` are backslash-escaped so they do not become headings.

        Args:
            text: Text to render; empty/``None`` yields ``""``.
            mode: ``"markdown"``, ``"text"`` or ``"python"`` give a fenced block tagged
                with that name; ``"quote"`` prefixes each line with ``> ``; ``""``
                returns the escaped text as is; anything else gives an untagged fence.
        """
        if not text:
            return ""
        lines = text.split('\n')
        escaped_lines = [f"\\{line}" if line.startswith('#') else line for line in lines]
        processed_text = '\n'.join(escaped_lines)
        if mode in ["markdown", "text", "python"]:
            return f"```{mode}\n{processed_text}\n```\n"
        if mode == "quote":
            return '\n'.join([f"> {line}" for line in escaped_lines]) + "\n"
        if mode == "":
            return processed_text + "\n"
        return f"```\n{processed_text}\n```\n"

    def log_planner_block(self, plan_raw: str, plan_sanitized: str, plan_digest: str) -> str:
        """Return a Markdown section showing the raw plan, the sanitized plan and the digest.

        ``None`` is accepted for any of the three arguments and rendered as empty.
        """
        raw_snip = self.get_debug_snippet(plan_raw)
        san_snip = self.get_debug_snippet(plan_sanitized)
        digest = (plan_digest or "").strip()

        out = []
        out.append("### Planner Output (raw)\n")
        out.append(self.format_markdown(raw_snip, mode="text"))
        out.append("### Planner Output (sanitized)\n")
        out.append(self.format_markdown(san_snip, mode="text"))
        out.append("### Plan Digest\n")
        out.append(self.format_markdown(digest, mode="text"))
        return "".join(out)

    def write_debug_logs(self, detailed_results, vote_dataframe, problem, problem_id="UNK", problem_time=""):
        """Write ``<problem_id>.md`` with prompts, settings, summary tables and attempt logs.

        Does nothing unless ``cfg.debug`` is true. Any exception is caught and reported
        with a printed message, so logging never interrupts the caller.

        Args:
            detailed_results: List of per-attempt dicts; each needs an ``'Attempt'`` key
                and may carry a ``'Log'`` entry. Columns named in ``cfg.debug_cols`` are
                left out of the summary table.
            vote_dataframe: pandas DataFrame with the vote counts.
            problem: Problem statement text.
            problem_id: Used in the heading and as the output file name.
            problem_time: Pre-formatted elapsed time shown in the header.
        """
        if not self.cfg.debug:
            return
        try:
            summary_lines = ["\n## Summary Stats\n"]
            if detailed_results:
                df = pd.DataFrame(detailed_results)
                cols = [c for c in df.columns if c not in self.cfg.debug_cols]
                summary_lines.append(df[cols].to_markdown(index=False))
                summary_lines.append("\n\n")

            if not vote_dataframe.empty:
                summary_lines.append("## Vote Counts\n")
                summary_lines.append(vote_dataframe.to_markdown(index=False))
                summary_lines.append("\n")

            final_log_content = [f"# Problem ID: {problem_id}\n"]
            final_log_content.append(f"Problem spent time: **{problem_time}**\n\n")
            final_log_content.append(f"**Problem:**\n{self.format_markdown(problem)}\n")
            final_log_content.append(f"**system_prompt:**\n{self.format_markdown(self.cfg.system_prompt)}\n")
            final_log_content.append(f"**tool_prompt:**\n{self.format_markdown(self.cfg.tool_prompt)}\n")
            final_log_content.append(f"**preference_prompt:**\n{self.format_markdown(self.cfg.preference_prompt)}\n")
            final_log_content.append(f"**planner_system_prompt:**\n{self.format_markdown(self.cfg.planner_system_prompt)}\n")
            final_log_content.append(f"**planner_prompt:**\n{self.format_markdown(self.cfg.planner_prompt)}\n")
            # The last entry used to read "**served_model_name: **value**", which is
            # unbalanced bold markup; it now follows the "**key**: value" pattern.
            final_log_content.append(f"**plan_enabled**: {self.cfg.plan_enabled}, "
                                     f"**digest_enabled**: {self.cfg.digest_enabled}, "
                                     f"**penalty_err_enabled**: {self.cfg.penalty_err_enabled}, "
                                     f"**penalty_err_alpha**: {self.cfg.penalty_err_alpha}, "
                                     f"**bonus_py_call**: {self.cfg.bonus_py_call}, "
                                     f"**bonus_py_alpha**: {self.cfg.bonus_py_alpha}, "
                                     f"**served_model_name**: {self.cfg.served_model_name}\n")
            final_log_content.extend(summary_lines)
            final_log_content.append("\n===\n")

            sorted_results = sorted(detailed_results, key=lambda x: x['Attempt'])
            for res in sorted_results:
                log_content = res.get('Log', '')
                if log_content:
                    final_log_content.append(log_content)
                    final_log_content.append("\n===\n")

            output_path = f"{problem_id}.md"
            with open(output_path, "w", encoding="utf-8") as f:
                f.write("".join(final_log_content))
            print(f"Debug log written to {output_path}")
        except Exception as e:
            print(f"Failed to write debug log: {e}")

In [7]:
class AIMO3Server:
    """Starts a local vLLM OpenAI-compatible server as a subprocess and waits until it answers.

    Constructing an instance reads the model files once (to warm the OS page cache),
    launches the server with its output redirected to ``vllm_server.log`` and then polls
    the ``/v1/models`` endpoint. The subprocess is terminated in ``__del__``.
    """

    def __init__(self, cfg, port: int = 8000):
        """Launch the server described by ``cfg`` and block until it is reachable.

        Args:
            cfg: Configuration object (model path, vLLM options, timeouts, ...).
            port: Fallback port, used only when ``cfg`` has no ``vllm_port`` attribute.

        Raises:
            RuntimeError: If the server process exits early or never becomes ready.
        """
        # Defined first so that __del__ is safe even if construction fails part-way.
        self.process = None
        self.log_file = None

        self.cfg = cfg
        self.port = getattr(self.cfg, "vllm_port", port)
        self.base_url = f"http://localhost:{self.port}/v1"
        self.api_key = "sk-local"
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key, timeout=self.cfg.session_timeout)
        
        self._preload_model_weights()
        self.process = self._start()
        self._wait_ready()

    # copy vLLM compile cache if available
    def _warmup_compile_cache(self):
        """Copy a prebuilt compile cache from ``compile_cache_src`` to ``compile_cache_dst``.

        Best-effort: does nothing when the feature is disabled, a path is unset or the
        source does not exist. A failed copy is reported with a printed warning instead
        of being silently ignored; it never raises.
        """
        if not getattr(self.cfg, "warmup_compile_cache", False):
            return
        src = getattr(self.cfg, "compile_cache_src", "")
        dst = getattr(self.cfg, "compile_cache_dst", "")
        if not src or not dst:
            return
        if os.path.exists(src):
            try:
                os.makedirs(os.path.dirname(dst), exist_ok=True)
                shutil.copytree(src, dst, dirs_exist_ok=True)
            except Exception as exc:
                print(f'[WARN] Could not copy vLLM compile cache from {src} to {dst}: {exc}')

    def _preload_model_weights(self) -> None:
        """Read every file under ``cfg.model_path`` once, using ``cfg.workers`` threads.

        The data is discarded; the stated purpose (see the printed message) is to pull
        the files into the OS page cache before the server loads them.
        """
        print(f'Loading model weights from {self.cfg.model_path} into OS Page Cache...')
        start_time = time.time()
        
        files_to_load = []
        total_size = 0
    
        for root, _, files in os.walk(self.cfg.model_path):
            for file_name in files:
                file_path = os.path.join(root, file_name)
    
                if os.path.isfile(file_path):
                    files_to_load.append(file_path)
                    total_size += os.path.getsize(file_path)
    
        def _read_file(path: str) -> None:
            # Sequential read in 1 GiB chunks; the chunks are dropped immediately.
            with open(path, 'rb') as file_object:
                while file_object.read(1024 * 1024 * 1024):
                    pass
    
        with ThreadPoolExecutor(max_workers=self.cfg.workers) as executor:
            # list(...) drains the iterator so exceptions raised in workers surface here.
            list(executor.map(_read_file, files_to_load))
    
        elapsed = time.time() - start_time
        print(f'Processed {len(files_to_load)} files ({total_size / 1e9:.2f} GB) in {elapsed:.2f} seconds.\n')
    
    def _start(self) -> subprocess.Popen:
        """Spawn the vLLM API server process and return its ``Popen`` handle.

        stdout and stderr go to ``vllm_server.log`` in the working directory; the
        process is started in its own session.
        """
        if self.cfg.warmup_compile_cache:
            self._warmup_compile_cache()

        cmd = [
            sys.executable, '-m', 'vllm.entrypoints.openai.api_server', 
            '--seed', str(self.cfg.seed), 
            '--model', self.cfg.model_path, 
            '--served-model-name', self.cfg.served_model_name, 
            '--tensor-parallel-size', '1', 
            '--max-num-seqs', str(self.cfg.batch_size), 
            '--gpu-memory-utilization', str(self.cfg.gpu_memory_utilization), 
            '--host', '0.0.0.0', 
            '--port', str(self.port), 
            '--dtype', self.cfg.dtype, 
            '--kv-cache-dtype', self.cfg.kv_cache_dtype, 
            '--max-model-len', str(self.cfg.context_tokens), 
            '--stream-interval', str(self.cfg.stream_interval), 
            '--async-scheduling', 
            '--disable-log-stats', 
            '--enable-prefix-caching'
        ]
    
        self.log_file = open('vllm_server.log', 'w')
        return subprocess.Popen(
            cmd, stdout=self.log_file, stderr=subprocess.STDOUT, start_new_session=True
        )
    
    def _wait_ready(self):
        """Poll the server until ``models.list()`` succeeds.

        ``cfg.server_timeout`` is the number of polling attempts, not a wall-clock
        limit: each failed attempt is followed by a one-second sleep, and the request
        itself may take additional time.

        Raises:
            RuntimeError: If the process exits (the full log is included in the
                message) or all attempts are used up.
        """
        print('Waiting for vLLM server...')
        start_time = time.time()
    
        for _ in range(self.cfg.server_timeout):
            return_code = self.process.poll()
    
            if return_code is not None:
                self.log_file.flush()
                with open('vllm_server.log', 'r') as log_file:
                    logs = log_file.read()
                raise RuntimeError(f'Server died with code {return_code}. Full logs:\n{logs}\n')
    
            try:
                self.client.models.list()
                elapsed = time.time() - start_time
                print(f'Server is ready (took {elapsed:.2f} seconds).\n')
                return
            except Exception:
                # Not reachable yet: wait and try again.
                time.sleep(1)
        raise RuntimeError('Server failed to start (timeout).\n')

    def __del__(self):
        """Terminate the server process and close the log file (best-effort)."""
        process = getattr(self, 'process', None)
        if process is not None:
            with contextlib.suppress(Exception):
                process.terminate()
                try:
                    # Bounded wait so that a server ignoring SIGTERM cannot hang cleanup.
                    process.wait(timeout=30)
                except subprocess.TimeoutExpired:
                    process.kill()
                    process.wait()
        log_file = getattr(self, 'log_file', None)
        if log_file is not None:
            with contextlib.suppress(Exception):
                log_file.close()

In [8]:
class AIMO3Planner:
    """
    Robust planner for Harmony-format models:
    - Collect token_ids and parse via harmony encoding to avoid 'assistantfinal' artifacts.
    - Enforce PLAN/DIGEST template with at most one repair.
    - Never return empty digest; never return non-bullet long paragraphs as plan.
    """

    def __init__(self, cfg, port: int = 8000):
        """Create the planner's OpenAI-compatible client and harmony encoding.

        Args:
            cfg: configuration object; ``vllm_port`` (optional) overrides
                ``port`` and ``session_timeout`` is used as the client timeout.
            port: port used when ``cfg`` has no ``vllm_port`` attribute.
        """
        resolved_port = getattr(cfg, "vllm_port", port)
        endpoint = f"http://localhost:{resolved_port}/v1"
        local_key = "sk-local"

        self.cfg = cfg
        self.port = resolved_port
        self.base_url = endpoint
        self.api_key = local_key
        self.client = OpenAI(
            base_url=endpoint,
            api_key=local_key,
            timeout=cfg.session_timeout,
        )

        harmony = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        self.encoding = harmony
        self.stop_token_ids = harmony.stop_tokens_for_assistant_actions()

    # ----------------- helpers -----------------

    def _build_history_block(self, history: list[dict]) -> str:
        if not history:
            return "ATTEMPT HISTORY: (none)\n"
        keep = getattr(self.cfg, "planner_history_keep", 8)
        maxc = getattr(self.cfg, "planner_digest_max_chars", 256)
        lines = ["ATTEMPT HISTORY (structured):"]
        for r in history[-keep:]:
            dig = (r.get("PlanDigest") or "").replace("\n", " ").strip()
            if len(dig) > maxc:
                dig = dig[:maxc] + "..."
            lines.append(
                f"- Attempt {r.get('Attempt')}: "
                f"Answer={r.get('Answer')}, Entropy={float(r.get('Entropy', 1e9)):.3f}, "
                f"PyCalls={int(r.get('Python Calls', 0) or 0)}, PyErr={int(r.get('Python Errors', 0) or 0)}; "
                f"PlanDigest={dig}"
            )
        return "\n".join(lines) + "\n"

    def _render_prompt_ids(self, system_prompt: str, user_text: str):
        """Render a two-message (system + user) conversation for completion.

        The system content carries ``system_prompt`` as model identity, LOW
        reasoning effort and a tool namespace named ``"none"`` with no tools.
        Returns whatever the harmony encoding's
        ``render_conversation_for_completion`` yields for an assistant turn.
        """
        system_content = SystemContent.new()
        system_content = system_content.with_model_identity(system_prompt)
        system_content = system_content.with_reasoning_effort(reasoning_effort=ReasoningEffort.LOW)
        system_content = system_content.with_tools(
            ToolNamespaceConfig(name="none", description="", tools=[])
        )

        system_message = Message.from_role_and_content(Role.SYSTEM, system_content)
        user_message = Message.from_role_and_content(Role.USER, user_text)
        conversation = Conversation.from_messages([system_message, user_message])

        return self.encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)

    def _decode_from_token_ids(self, token_ids: list[int]) -> str:
        if not token_ids:
            return ""
        msgs = self.encoding.parse_messages_from_completion_tokens(token_ids, Role.ASSISTANT)
        # concatenate all assistant text contents
        parts = []
        for m in msgs:
            if not m.content:
                continue
            for c in m.content:
                if hasattr(c, "text") and c.text:
                    parts.append(c.text)
        return "".join(parts).strip()

    def _stream_completion(self, system_prompt: str, user_text: str, seed: int, max_tokens: int, temperature: float):
        prompt_ids = self._render_prompt_ids(system_prompt, user_text)
        req_timeout = min(30.0, float(self.cfg.session_timeout))
        try:
            stream = self.client.completions.create(
                model=self.cfg.served_model_name,
                temperature=float(temperature),
                max_tokens=int(max_tokens),
                prompt=prompt_ids,
                seed=int(seed),
                stream=True,
                timeout=req_timeout,
                extra_body={
                    "min_p": self.cfg.min_p,
                    "stop_token_ids": self.stop_token_ids,
                    "return_token_ids": True,
                },
            )
        except Exception as e:
            # ËøîÂõûÁ©∫ plannerÔºåËÆ© solver Ëµ∞ fallback/Êó† plan Ë∑ØÂæÑ
            return "", f"[PLANNER_STREAM_CREATE_ERROR] {e}"

        token_buf = []
        text_buf = []
        try:
            for chunk in stream:
                # token_ids is the reliable path for harmony parsing
                tids = getattr(chunk.choices[0], "token_ids", None)
                if tids:
                    token_buf.extend(tids)
                t = chunk.choices[0].text
                if t:
                    text_buf.append(t)
        finally:
            stream.close()

        # raw text is only for debugging
        raw_text = "".join(text_buf).strip()
        parsed_text = self._decode_from_token_ids(token_buf)
        return parsed_text, raw_text

    def _bulletize(self, text: str, max_lines: int = 8) -> str:
        t = (text or "").strip()
        if not t:
            return ""
        lines = [ln.strip() for ln in t.splitlines() if ln.strip()]
        out = []
        for ln in lines[:max_lines]:
            ln = re.sub(r"^\s*(analysis|final|commentary)\s*", "", ln, flags=re.IGNORECASE).strip()
            if not ln:
                continue
            if not ln.startswith(("-", "*", "‚Ä¢")):
                ln = "- " + ln
            out.append(ln)
        return "\n".join(out).strip()

    def _strip_answerish_lines(self, plan_part: str) -> str:
        lines = plan_part.splitlines()
        bad_kw = ("\\boxed", "remainder", "mod", "final answer", "=", "so each", "therefore", "template", "assistantfinal", "assistant")
        out = []
        for ln in lines:
            low = ln.lower()
            # Âè™ËøáÊª§‚ÄúÁ≠îÊ°àÂûãÊÄªÁªìÂè•‚ÄùÔºå‰øùÁïôÊ≠£Â∏∏Êï∞Â≠¶Á≠âÂºèÁöÑÂèØËÉΩÊÄß‰ºöÊúâËØØ‰º§Ôºå‰ΩÜÂú® planner Âú∫ÊôØÂà©Â§ß‰∫éÂºä
            if any(k in low for k in bad_kw) and any(ch.isdigit() for ch in ln):
                continue
            out.append(ln)
        return "\n".join(out).strip()

    def _extract_plan_and_digest(self, text: str) -> tuple[str, str]:
        """
        Accept headers with optional bullet prefixes:
          PLAN: or - PLAN:
          DIGEST: or - DIGEST:
        """
        t = (text or "").strip()
        if not t:
            return "", ""

        # normalize <<<END>>> cut
        t = re.split(r"(?m)^\s*<<<END>>>\s*$", t)[0].strip()

        # find headers (allow optional bullet prefix)
        plan_hdr = re.search(r"(?im)^\s*(?:[-*‚Ä¢]\s*)?PLAN\s*:\s*$", t)
        dig_hdr = re.search(r"(?im)^\s*(?:[-*‚Ä¢]\s*)?DIGEST\s*:\s*$", t)

        if plan_hdr and dig_hdr and dig_hdr.start() > plan_hdr.end():
            plan_block = t[plan_hdr.end():dig_hdr.start()].strip()
            dig_block = t[dig_hdr.end():].strip()
        else:
            # fallback: try split by substring if headers are inline
            if "DIGEST:" in t:
                left, right = t.split("DIGEST:", 1)
                plan_block = left
                dig_block = right
                plan_block = re.sub(r"(?im)^\s*(?:[-*‚Ä¢]\s*)?PLAN\s*:\s*", "", plan_block).strip()
            else:
                plan_block, dig_block = t, ""

        # keep only bullet lines for plan
        plan_lines = []
        for ln in plan_block.splitlines():
            ln = ln.strip()
            if ln.startswith(("-", "*", "‚Ä¢")):
                plan_lines.append(ln)
        plan_part = "\n".join(plan_lines).strip()

        # digest: first bullet line
        digest_line = ""
        for ln in dig_block.splitlines():
            ln = ln.strip()
            if ln.startswith(("-", "*", "‚Ä¢")):
                digest_line = ln.lstrip("-*‚Ä¢ ").strip()
                break
            
        plan_part = re.sub(r"(?i)\b<<<END>>>\b", "", plan_part).strip()
        digest_line = re.sub(r"(?i)\s*\b<<<END>>>\b\s*$", "", digest_line).strip()

        return plan_part, digest_line

    def _make_digest_fallback(self, plan_text: str) -> str:
        maxc = getattr(self.cfg, "planner_digest_max_chars", 256)
        lines = [ln.strip() for ln in (plan_text or "").splitlines() if ln.strip()]
        bullets = [ln.lstrip("-*‚Ä¢ ").strip() for ln in lines if ln.startswith(("-", "*", "‚Ä¢"))]
        s = (bullets[0] if bullets else "").strip()
        if not s:
            s = "Try a different approach; enforce small scans + modular checks + caching."
        return s[:maxc].strip()

    def _is_good(self, plan_part: str, digest: str) -> bool:
        if not plan_part:
            return False
        nbul = sum(1 for ln in plan_part.splitlines() if ln.strip().startswith(("-", "*", "‚Ä¢")))
        if nbul < 5:
            return False
        return bool(digest.strip())

    # ----------------- main API -----------------

    def gen_plan(self, problem_text: str, history: list[dict], attempt_index: int):
        # NOTE: planner should be short; do NOT set serial_plan_max_tokens too large.
        max_tokens = min(int(self.cfg.serial_plan_max_tokens), 512)
        temp = float(getattr(self.cfg, "serial_aux_temperature", 0.7))

        user_prompt = (
            f"{self.cfg.planner_prompt}\n\n"
            f"PROBLEM:\n{problem_text}\n\n"
            f"{self._build_history_block(history)}\n"
            f"Output now."
        )

        seed = int((self.cfg.seed + 777) * (attempt_index + 1) ** 2)
        parsed_text, raw_text = self._stream_completion(
            system_prompt=self.cfg.planner_system_prompt,
            user_text=user_prompt,
            seed=seed,
            max_tokens=max_tokens,
            temperature=temp,
        )

        plan_part, digest = self._extract_plan_and_digest(parsed_text)

        # one repair if bad
        if not self._is_good(plan_part, digest):
            repair_user = (
                "FORMATTER TASK. Output ONLY the required template.\n"
                "First line: PLAN:\n"
                "Then 5-8 lines starting with '- '\n"
                "Then: DIGEST:\n"
                "Then exactly one '- ' line (<=256 chars)\n"
                "Then: <<<END>>>\n\n"
                f"PROBLEM:\n{problem_text}\n"
            )
            parsed_text2, raw_text2 = self._stream_completion(
                system_prompt=self.cfg.planner_system_prompt,
                user_text=repair_user,
                seed=seed + 1,
                max_tokens=256,
                temperature=0.0,
            )
            parsed_text = parsed_text2 or parsed_text
            raw_text = raw_text2 or raw_text
            plan_part, digest = self._extract_plan_and_digest(parsed_text)

        # plan_part = self._strip_answerish_lines(plan_part)
        # program-side hard fallback (never return garbage paragraphs)
        if not plan_part.strip():
            plan_part = self._bulletize(parsed_text) or self._bulletize(problem_text)
        if not digest.strip():
            digest = self._make_digest_fallback(plan_part)

        # truncate plan only (digest stays useful)
        # plan_part = plan_part[: self.cfg.serial_context_char_limit].strip()
        # if len(digest) > self.cfg.planner_digest_max_chars:
        #     digest = digest[: self.cfg.planner_digest_max_chars].strip()

        # return both parsed and raw for logging
        plan_sanitized = plan_part  # already bullet-only; treat as sanitized
        plan_raw = raw_text or parsed_text
        return plan_part, digest, plan_raw, plan_sanitized

In [9]:
class AIMO3Solver:

    def __init__(self, cfg, port: int = 8000):
        """Set up the API client, harmony encoding, logger, planner and sandboxes.

        Args:
            cfg: configuration object; ``vllm_port`` (optional) overrides
                ``port`` and ``session_timeout`` is used as the client timeout.
            port: port used when ``cfg`` has no ``vllm_port`` attribute.
        """
        self.cfg = cfg
        self.port = getattr(self.cfg, "vllm_port", port)
        self.base_url = f"http://localhost:{self.port}/v1"
        self.api_key = 'sk-local'
        self.client = OpenAI(base_url=self.base_url, api_key=self.api_key, timeout=self.cfg.session_timeout)

        self.encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
        self.stop_token_ids = self.encoding.stop_tokens_for_assistant_actions()

        # Start of the global time budget used by solve_problem().
        self.notebook_start_time = time.time()
        self.problems_remaining = 50
        self.logger = AIMO3Logger(cfg)
        # Forward the resolved port so the planner talks to the same server
        # even when it came from the `port` argument rather than cfg.vllm_port.
        self.planner = AIMO3Planner(cfg, port=self.port)
        self._initialize_kernels()
    
    def _initialize_kernels(self) -> None:
        """Create ``cfg.workers`` sandboxes concurrently and pool them.

        Sandboxes are built in a thread pool and put on ``self.sandbox_pool``
        in completion order. Elapsed time is printed when done.
        """
        print(f'Initializing {self.cfg.workers} persistent Jupyter kernels...')
        began = time.time()

        self.sandbox_pool = queue.Queue()

        def build_one():
            return AIMO3Sandbox(timeout=self.cfg.jupyter_timeout)

        with ThreadPoolExecutor(max_workers=self.cfg.workers) as builders:
            pending = []
            for _ in range(self.cfg.workers):
                pending.append(builders.submit(build_one))

            for finished in as_completed(pending):
                ready_sandbox = finished.result()
                self.sandbox_pool.put(ready_sandbox)

        took = time.time() - began
        print(f'Kernels initialized in {took:.2f} seconds.\n')

    def _scan_for_answer(self, text: str) -> int | None:
        pattern = r'\\boxed\s*\{\s*([0-9,]+)\s*\}'
        matches = re.findall(pattern, text)
    
        if matches:
            try:
                clean_value = matches[-1].replace(',', '')
                value = int(clean_value)
    
                if 0 <= value <= 99999:
                    return value
    
            except ValueError:
                pass
                
        pattern = r'final\s+answer\s+is\s*([0-9,]+)'
        matches = re.findall(pattern, text, re.IGNORECASE)
    
        if matches:
            try:
                clean_value = matches[-1].replace(',', '')
                value = int(clean_value)
    
                if 0 <= value <= 99999:
                    return value
    
            except ValueError:
                pass
    
        return None
    
    def _compute_mean_entropy(self, logprobs_buffer: list) -> float:
        if not logprobs_buffer:
            return float('inf')
    
        total_entropy = 0.0
        token_count = 0
    
        for top_logprobs_dict in logprobs_buffer:
            
            if not isinstance(top_logprobs_dict, dict):
                continue
            
            if not top_logprobs_dict:
                continue
            
            token_entropy = 0.0
            
            for token_str, log_prob in top_logprobs_dict.items():
                prob = math.exp(log_prob)
                
                if prob > 0:
                    token_entropy -= prob * math.log2(prob)
            
            total_entropy += token_entropy
            token_count += 1
    
        if token_count == 0:
            return float('inf')
    
        return total_entropy / token_count

    def _process_attempt(
        self, 
        problem: str, 
        system_prompt: str, 
        attempt_index: int, 
        stop_event: threading.Event, 
        deadline: float,
        problem_id: str,
        plan: str = "",
        plan_digest: str = "",
        plan_raw: str = "",
        plan_sanitized: str = "",
    ) -> dict:
        """Run one independent solving attempt and return its result record.

        The attempt borrows a sandbox from ``self.sandbox_pool``, builds a
        harmony conversation (system prompt with HIGH reasoning effort and the
        tool config, plus the problem with the optional plan appended) and
        then loops for at most ``cfg.turns`` turns. Each turn streams a
        completion, scans the streamed text for an answer, and otherwise
        either finishes on a ``final`` channel message or executes a
        ``python``-addressed message through the tool and appends its output.

        Args:
            problem: problem text sent as the user message.
            system_prompt: model identity for the system message.
            attempt_index: zero-based attempt number (used for the seed and
                the reported ``Attempt`` value).
            stop_event: set by the caller to abort outstanding attempts.
            deadline: absolute ``time.time()`` value after which work stops.
            problem_id: identifier used only in the timeout message.
            plan / plan_digest / plan_raw / plan_sanitized: planner outputs;
                ``plan`` is appended to the prompt, all four are echoed in
                the result.

        Returns:
            A dict with ``Attempt``, ``Response Length``, ``Python Calls``,
            ``Python Errors``, ``Entropy``, ``Answer``, ``Plan``,
            ``PlanDigest``, ``PlanRaw``, ``PlanSanitized``, ``Log`` and
            ``Time``. NOTE(review): the early-exit record returned when the
            attempt is already cancelled/expired omits the ``Plan*`` and
            ``Time`` keys - confirm downstream consumers tolerate that.

        Exceptions raised inside the main ``try`` are not propagated: they are
        counted in ``Python Errors`` and, in debug mode, written to the log.
        """
        attempt_log = deque([])
        attempt_start = time.time()

        # Cancelled or out of time before starting: return a minimal record.
        if stop_event.is_set() or time.time() > deadline:
            print(f"Problem: {problem_id} TIMEOUT!")
            return {
                'Attempt': attempt_index + 1, 
                'Answer': None, 
                'Python Calls': 0, 
                'Python Errors': 0, 
                'Response Length': 0, 
                'Entropy': float('inf'),
                'Log': "\n".join(attempt_log)
            }
    
        local_tool = None
        sandbox = None
        python_calls = 0
        python_errors = 0
        total_tokens = 0
        final_answer = None
        
        logprobs_buffer = []
        # Per-attempt seed; the same seed is reused for every turn of the attempt.
        attempt_seed = int(math.pow(self.cfg.seed + attempt_index, 2))
    
        try:
            # Blocks until a sandbox is free; raises if none frees up in time.
            sandbox = self.sandbox_pool.get(timeout=self.cfg.sandbox_timeout)
            local_tool = AIMO3Tool(
                local_jupyter_timeout=self.cfg.jupyter_timeout, 
                tool_prompt=self.cfg.tool_prompt, 
                sandbox=sandbox
            )
            encoding = self.encoding

            # Optional planner augmentation appended to the problem text.
            aug = ""
            if plan:
                aug += f"\n\n=== CURRENT ATTEMPT PLAN ===\n{plan}\n"
            if aug:
                aug += "\nFollow the plan.\n"
            full_problem = problem + aug
            if self.cfg.debug and self.logger:
                attempt_log.append(self.logger.log_planner_block(plan_raw, plan_sanitized, plan_digest))
                attempt_log.append("### Planner Augmentation\n")
                attempt_log.append(f"{self.logger.format_markdown(aug, mode='text')}\n")

            sys_content = (
                SystemContent.new()
                .with_model_identity(system_prompt)
                .with_reasoning_effort(reasoning_effort=ReasoningEffort.HIGH)
                .with_tools(local_tool.tool_config)
            )
            messages = [
                Message.from_role_and_content(Role.SYSTEM, sys_content),
                Message.from_role_and_content(Role.USER, full_problem), # problem
            ]
            conversation = Conversation.from_messages(messages)

            for turn_i in range(self.cfg.turns):
                if stop_event.is_set() or time.time() > deadline:
                    break
    
                # Re-render the whole conversation each turn; whatever context
                # is left after the prompt is the generation budget.
                prompt_ids = encoding.render_conversation_for_completion(conversation, Role.ASSISTANT)
                max_tokens = self.cfg.context_tokens - len(prompt_ids)
    
                # Stop when the remaining context is below the configured reserve.
                if max_tokens < self.cfg.buffer_tokens:
                    break

                if self.cfg.debug and self.cfg.debug_req:
                    # Decode the rendered prompt ids back to text for the debug
                    # log (per the original note this includes the model's
                    # special symbols).
                    full_request_text = encoding.decode(prompt_ids)
                    
                    snippet = self.logger.get_debug_snippet(full_request_text)
                    formatted_req = self.logger.format_markdown(snippet)
                    attempt_log.append(f"### Turn {turn_i} - Raw Request to Model:")
                    attempt_log.append(formatted_req)

                try:
                    # Never wait past the deadline, but allow at least one second.
                    req_timeout = max(1.0, deadline - time.time())
                    stream = self.client.completions.create(
                        model=self.cfg.served_model_name, 
                        temperature=self.cfg.temperature, 
                        logprobs=self.cfg.top_logprobs, 
                        max_tokens=max_tokens, 
                        prompt=prompt_ids, 
                        seed=attempt_seed, 
                        stream=True, 
                        timeout=req_timeout,
                        extra_body={
                            'min_p': self.cfg.min_p, 
                            'stop_token_ids': self.stop_token_ids, 
                            'return_token_ids': True
                        },
                    )
                except Exception as e:
                    # stream create failed -> don't hang; record & break this attempt turn
                    if self.cfg.debug:
                        attempt_log.append(
                            f"### Turn {turn_i} - Stream Create Failed (timeout={req_timeout:.1f}s)\n"
                            f"{self.logger.format_markdown(str(e), mode='text')}\n"
                        )
                    break

                full_response_text = ""
                try:
                    token_buffer = []
                    text_chunks = []
    
                    for chunk in stream:
                        if stop_event.is_set() or time.time() > deadline:
                            break
    
                        new_tokens = chunk.choices[0].token_ids
                        new_text = chunk.choices[0].text
    
                        if new_tokens:
                            token_buffer.extend(new_tokens)
                            total_tokens += len(new_tokens)
                            text_chunks.append(new_text)
                            if self.cfg.debug and self.cfg.debug_resp:
                                full_response_text += new_text

                            chunk_logprobs = chunk.choices[0].logprobs
                            if chunk_logprobs is not None:
                                if chunk_logprobs.top_logprobs:
                                    logprobs_buffer.extend(chunk_logprobs.top_logprobs)
    
                        # A closing brace may complete a \boxed{...}; scan only
                        # the last cfg.search_tokens text chunks for an answer.
                        if '}' in new_text:
                            search_text = ''.join(text_chunks[-self.cfg.search_tokens:])
                            answer = self._scan_for_answer(search_text)
    
                            if answer is not None:
                                final_answer = answer
                                break
    
                finally:
                    # Always release the HTTP stream, including on early break.
                    stream.close()

                if self.cfg.debug and full_response_text:
                    attempt_log.append(f"### Turn {turn_i} - Model Response:")
                    formatted_resp = self.logger.format_markdown(full_response_text)
                    attempt_log.append(formatted_resp)

                if final_answer is not None:
                    break
    
                if not token_buffer:
                    break
    
                # Turn the raw tokens into harmony messages and extend the
                # conversation; only the last message decides what happens next.
                new_messages = encoding.parse_messages_from_completion_tokens(token_buffer, Role.ASSISTANT)
                conversation.messages.extend(new_messages)
                last_message = new_messages[-1]
    
                if last_message.channel == 'final':
                    answer_text = last_message.content[0].text
                    final_answer = self._scan_for_answer(answer_text)
                    break
    
                if last_message.recipient == 'python':
                    python_calls += 1
                    tool_responses = local_tool.process_sync_plus(last_message)
                    response_text = tool_responses[0].content[0].text

                    if self.cfg.debug:
                        code_content = last_message.content[0].text
                        attempt_log.append(f"### Turn {turn_i} - Python Call:")
                        attempt_log.append(f"```python\n{code_content}\n```\n")

                        attempt_log.append(f"### Turn {turn_i} - Python Output:")
                        snippet_out = self.logger.get_debug_snippet(response_text)
                        formatted_out = self.logger.format_markdown(snippet_out, mode="text")
                        attempt_log.append(f"{formatted_out}\n")

                    # Text heuristic for a failed tool run.
                    if response_text.startswith('[ERROR]') or 'Traceback' in response_text or 'Error:' in response_text:
                        python_errors += 1
    
                    conversation.messages.extend(tool_responses)
    
        except Exception as exc:
            # Any failure in the attempt (not only tool failures) is counted
            # as a Python error and the attempt ends with what it has so far.
            python_errors += 1
            if self.cfg.debug:
                attempt_log.append(f"\n**EXCEPTION:** {str(exc)}\n")

        finally:
            # Hand the sandbox back to the pool for the next attempt.
            if sandbox is not None:
                sandbox.reset()
                self.sandbox_pool.put(sandbox)
    
        mean_entropy = self._compute_mean_entropy(logprobs_buffer)
        attempt_elapsed = time.time() - attempt_start
        attempt_time = _fmt_time(attempt_elapsed)
        if self.cfg.debug:
            # appendleft twice: the "## Attempt" heading ends up first,
            # followed by the time line.
            attempt_log.appendleft(f"Attempt spent time: **{attempt_time}**\n")
            attempt_log.appendleft(f"## Attempt {attempt_index + 1}\n")
    
        return {
            'Attempt': attempt_index + 1, 
            'Response Length': total_tokens, 
            'Python Calls': python_calls, 
            'Python Errors': python_errors, 
            'Entropy': mean_entropy, 
            'Answer': final_answer,
            'Plan': plan,
            'PlanDigest': plan_digest,
            'PlanRaw': plan_raw,
            'PlanSanitized': plan_sanitized,
            'Log': "\n".join(attempt_log),
            'Time': attempt_time
        }
    
    def _select_answer_by_vote(self, detailed_results: list) -> tuple[pd.DataFrame, int]:
        stats = defaultdict(lambda: {'votes': 0, 'calls': 0})
        for result in detailed_results:
            answer = result['Answer']

            if answer is not None:
                stats[answer]['votes'] += 1
                stats[answer]['calls'] += result['Python Calls']

        sorted_stats = sorted(
            stats.items(), 
            key=lambda item: (item[1]['votes'], item[1]['calls']), 
            reverse=True
        )

        vote_data = []

        for answer, data in sorted_stats:
            vote_data.append((answer, data['votes'], data['calls']))

        vote_dataframe = pd.DataFrame(vote_data, columns=['Answer', 'Votes', 'Calls'])
        display(vote_dataframe)

        final_answer = sorted_stats[0][0]
        final_votes = sorted_stats[0][1]['votes']
        final_calls = sorted_stats[0][1]['calls']

        print(f'\nFinal Result: {final_answer} | Votes: {final_votes} | Calls: {final_calls}\n')
        return vote_dataframe, final_answer

    def _calc_score(self, result) -> float:
        entropy = result['Entropy']
        py_calls = int(result.get('Python Calls', 0) or 0)
        py_errs = int(result.get('Python Errors', 0) or 0)
        score = 1.0 / max(entropy, 1e-9)

        penalty_err_enabled = getattr(self.cfg, "penalty_err_enabled", False)
        penalty_err_alpha = float(getattr(self.cfg, "penalty_err_alpha", 0.05))
        if penalty_err_enabled and penalty_err_alpha > 0:
            score *= 1.0 / (1.0 + self.cfg.penalty_err_alpha * py_errs)
    
        mode = getattr(self.cfg, "vote_py_bonus_mode", "none")
        if mode and mode != "none":
            if mode == "calls":
                sig = py_calls
            elif mode == "ok_calls":
                sig = max(0, py_calls - py_errs)
            else:
                sig = 0

            cap = int(getattr(self.cfg, "vote_py_bonus_cap", 4))
            alpha = float(getattr(self.cfg, "vote_py_bonus_alpha", 0.05))
            sig = min(sig, cap)
            score *= (1.0 + alpha * sig)
        return score

    def _select_answer_by_score(self, detailed_results: list) -> tuple[pd.DataFrame, int]:
        answer_score = defaultdict(float)
        answer_votes = defaultdict(int)

        for result in detailed_results:
            answer = result['Answer']
            
            if answer is not None:
                score = self._calc_score(result)
                answer_score[answer] += score
                answer_votes[answer] += 1

        scored_answers = []

        for answer, total_score in answer_score.items():
            scored_answers.append({
                'answer': answer, 
                'votes': answer_votes[answer], 
                'score': total_score
            })

        scored_answers.sort(key=lambda x: x['score'], reverse=True)

        vote_data = []

        for item in scored_answers:
            vote_data.append((
                item['answer'], 
                item['votes'], 
                item['score']
            ))

        vote_dataframe = pd.DataFrame(
            vote_data, 
            columns=['Answer', 'Votes', 'Score']
        )

        vote_dataframe = vote_dataframe.round({'Score': 3})
        display(vote_dataframe)
        
        if not scored_answers:
            print('\nFinal Answer: 0\n')
            return vote_dataframe, 0

        final_answer = scored_answers[0]['answer']    
        print(f'\nFinal Answer: {final_answer}\n')
        return vote_dataframe, final_answer

    def _select_answer(self, detailed_results: list) -> tuple[pd.DataFrame, int]:
        if self.cfg.select_policy == 'vote':
            return self._select_answer_by_vote(detailed_results)
        else:
            return self._select_answer_by_score(detailed_results)

    def solve_problem(self, problem: str, problem_id: str = "UNK") -> int:
        """Solve one problem with ``cfg.attempts`` parallel attempts.

        A time budget is derived from the remaining notebook time minus a
        base reservation for every other outstanding problem, clamped to
        ``[cfg.base_problem_timeout, cfg.high_problem_timeout]``. Attempts
        run in a thread pool; once any answer has been returned by
        ``cfg.early_stop`` attempts, the remaining attempts are signalled to
        stop. The final answer is chosen by ``_select_answer`` (or is ``0``
        when no attempt produced one), and debug logs are written.

        Returns:
            The selected answer.
        """
        print(f'\nProblem: {problem}\n')
        problem_start = time.time()

        user_input = f'{problem} {self.cfg.preference_prompt}'

        # ---- time budget ----
        elapsed_global = time.time() - self.notebook_start_time
        time_left = self.cfg.notebook_limit - elapsed_global
        other_problems = max(0, self.problems_remaining - 1)
        reserved_time = other_problems * self.cfg.base_problem_timeout

        budget = min(time_left - reserved_time, self.cfg.high_problem_timeout)
        budget = max(budget, self.cfg.base_problem_timeout)

        deadline = time.time() + budget

        print(f'Budget: {budget:.2f} seconds | Deadline: {deadline:.2f}\n')

        # One (system prompt, attempt index) job per configured attempt.
        tasks = [(self.cfg.system_prompt, idx) for idx in range(self.cfg.attempts)]

        detailed_results = []
        valid_answers = []
        stop_event = threading.Event()
        executor = ThreadPoolExecutor(max_workers=self.cfg.workers)

        try:
            futures = []
            for system_prompt, attempt_index in tasks:
                futures.append(
                    executor.submit(
                        self._process_attempt,
                        user_input,
                        system_prompt,
                        attempt_index,
                        stop_event,
                        deadline,
                        problem_id,
                    )
                )

            for future in as_completed(futures):
                try:
                    result = future.result()
                    detailed_results.append(result)
                    if result['Answer'] is not None:
                        valid_answers.append(result['Answer'])

                    # Early stop once the leading answer has enough votes.
                    leader = Counter(valid_answers).most_common(1)
                    if leader and leader[0][1] >= self.cfg.early_stop:
                        stop_event.set()
                        for pending in futures:
                            pending.cancel()
                        break

                except Exception as exc:
                    print(f'Future failed: {exc}')
                    continue

        finally:
            # Make running attempts wind down, then wait for them so every
            # sandbox is back in the pool before the next problem.
            stop_event.set()
            executor.shutdown(wait=True, cancel_futures=True)
            self.problems_remaining = max(0, self.problems_remaining - 1)

        if detailed_results:
            results_dataframe = pd.DataFrame(detailed_results)
            results_dataframe['Entropy'] = results_dataframe['Entropy'].round(3)
            results_dataframe['Answer'] = results_dataframe['Answer'].astype('Int64')

            hidden = self.cfg.debug_cols
            cols = [c for c in results_dataframe.columns if c not in hidden]
            display(results_dataframe[cols])

        problem_elapsed = time.time() - problem_start
        problem_time = _fmt_time(problem_elapsed)

        if valid_answers:
            vote_data, final_answer = self._select_answer(detailed_results)
        else:
            print('\nResult: 0\n')
            vote_data, final_answer = pd.DataFrame(columns=['Answer', 'Votes', 'Score']), 0

        print(f"Problem ID: {problem_id}, spent time: {problem_time}, problems_remaining = {self.problems_remaining}")
        self.logger.write_debug_logs(detailed_results, vote_data, problem, problem_id, problem_time)
        return final_answer

    def __del__(self):
        if hasattr(self, 'sandbox_pool'):
            while not self.sandbox_pool.empty():
                try:
                    sb = self.sandbox_pool.get_nowait()
                    sb.close()
                except Exception:
                    pass

In [10]:
# (Re)create the global server and solver. `_delete` is defined earlier in the
# notebook; presumably it disposes of an existing global with that name so the
# cell can be re-run - confirm against its definition.
# Order matters: the solver is built after the server object has been created.
_delete("server")
server = AIMO3Server(CFG)

_delete("solver")
solver = AIMO3Solver(CFG)

Loading model weights from /kaggle/input/gpt-oss-120b/transformers/default/1 into OS Page Cache...
Processed 26 files (65.28 GB) in 65.82 seconds.

Waiting for vLLM server...
Server is ready (took 133.29 seconds).

Initializing 16 persistent Jupyter kernels...
Kernels initialized in 2.91 seconds.



In [11]:
# id -> predicted answer, filled by predict() and read by test() and the summary cell.
predict_answers = {}

def predict(id_: pl.DataFrame, question: pl.DataFrame, answer: Optional[pl.DataFrame] = None) -> pl.DataFrame:
    """Solve one problem for the evaluation gateway.

    Args:
        id_: frame whose first item is the problem id.
        question: frame whose first item is the problem text.
        answer: unused; accepted for interface compatibility.

    Returns:
        A one-row frame with columns ``id`` and ``answer``.

    The garbage collector is switched off while the solver runs and is
    re-enabled in a ``finally`` block, so an exception from the solver cannot
    leave it disabled for the rest of the session.
    """
    id_value = id_.item(0)
    question_text = question.item(0)
    gc.disable()
    try:
        final_answer = solver.solve_problem(question_text, problem_id=str(id_value))
    finally:
        gc.enable()
        gc.collect()
    predict_answers[id_value] = final_answer
    return pl.DataFrame({'id': id_value, 'answer': final_answer})

# Wrap predict() in the competition's inference server; it is started by the
# run cell via .serve() or, for local checks, via .run_local_gateway() in test().
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

def test():
    global predict_answers

    # test_csv = '/kaggle/input/ai-mathematical-olympiad-progress-prize-3/test.csv'
    test_csv = '/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv'
    # test_csv = '/kaggle/input/aimo-p3-hard/test2.csv'
    # test_csv = '/kaggle/input/aimo-p3-hard/test3.csv'
    # test_csv = '/kaggle/input/aimo-p3-hard/p5.csv'
    # test_csv = '/kaggle/input/aimo-p3-hard/p10.csv'

    inference_server.run_local_gateway((test_csv,))

    df = pd.read_csv(test_csv)
    real_answers = dict(zip(df["id"], df["answer"])) if "answer" in df.columns else {}
    correct_count = 0
    total_count = 0
    # Check accuracy if ground truth available
    for (id, predict_answer) in predict_answers:
        if id in real_answers:
            total_count += 1
            real_answer = real_answers[id]
            is_correct = (predict_answer == real_answer)
            if is_correct:
                correct_count += 1
            status = "‚úÖ" if is_correct else "‚ùå"
            print(f"Problem ID: {id} -- Predict Answer: {predict_answer} | Ground Truth: {real_answer} | {status}")
        else:
            print(f"Problem ID: {id} -- Predict Answer: {predict_answer}")
    print(f"üìä Running Accuracy: {correct_count}/{total_count} ({100*correct_count/total_count:.1f}%)")

In [12]:
# Entry point. Outside a competition rerun, evaluate locally via `test()`.
# During the rerun (Kaggle sets KAGGLE_IS_COMPETITION_RERUN), switch debug
# output off and hand control to the inference server.
if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
    test()
else:
    CFG.debug = False
    inference_server.serve()

# Wall-clock time since the solver recorded `notebook_start_time`.
global_spent = time.time() - solver.notebook_start_time
print(f"Total {len(predict_answers)} problems finished in {_fmt_time(global_spent)}!")


Problem: On a blackboard, Ken starts off by writing a positive integer $n$ and then applies the following move until he first reaches $1$. Given that the number on the board is $m$, he chooses a base $b$, where $2 \leq b \leq m$, and considers the unique base-$b$ representation of $m$,
\begin{equation*}
    m = \sum_{k = 0}^\infty a_k \cdot b^k
\end{equation*}
where $a_k$ are non-negative integers and $0 \leq a_k < b$ for each $k$. Ken then erases $m$ on the blackboard and replaces it with $\sum\limits_{k = 0}^\infty a_k$.

Across all choices of $1 \leq n \leq 10^{10^5}$, the largest possible number of moves Ken could make is $M$. What is the remainder when $M$ is divided by $10^{5}$?

Budget: 900.00 seconds | Deadline: 1769157048.26



Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,7,4912,4,0,0.73,32193,1:03
1,5,5973,5,0,0.721,32193,1:13
2,1,6754,1,0,0.772,32193,1:20
3,6,6595,9,0,0.721,32193,1:23


Unnamed: 0,Answer,Votes,Calls
0,32193,4,19



Final Result: 32193 | Votes: 4 | Calls: 19

Problem ID: 42d360, spent time: 1:24, problems_remaining = 49
Debug log written to 42d360.md

Problem: Let $\mathcal{F}$ be the set of functions $\alpha \colon \mathbb{Z}\to \mathbb{Z}$ for which there are only finitely many $n \in \mathbb{Z}$ such that $\alpha(n) \neq 0$. 

For two functions $\alpha$ and $\beta$ in $\mathcal{F}$, define their product $\alpha\star\beta$ to be $\sum\limits_{n\in\mathbb{Z}} \alpha(n)\cdot \beta(n)$. Also, for $n\in\mathbb{Z}$, define a shift operator $S_n \colon \mathcal{F}\to \mathcal{F}$ by $S_n(\alpha)(t)=\alpha(t+n)$ for all $t \in \mathbb{Z}$.

A function $\alpha \in \mathcal{F}$ is called \emph{shifty} if 
\begin{itemize}
    \item $\alpha(m)=0$ for all integers $m<0$ and $m>8$ and
    \item There exists $\beta \in \mathcal{F}$ and integers $k \neq l$ such that for all $n \in \mathbb{Z}$
    \begin{equation*}
        S_n(\alpha)\star\beta =
        \begin{cases}
            1 & n \in \{k,l\} \\
         

Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,6,14991,10,1,0.805,44,2:32
1,7,15778,19,2,0.726,44,2:41
2,4,16273,13,1,0.738,44,2:46
3,2,14868,17,6,0.734,44,2:58


Unnamed: 0,Answer,Votes,Calls
0,44,4,59



Final Result: 44 | Votes: 4 | Calls: 59

Problem ID: dd7f5e, spent time: 3:02, problems_remaining = 48
Debug log written to dd7f5e.md

Problem: Let $n \geq 6$ be a positive integer. We call a positive integer $n$-Norwegian if it has three distinct positive divisors whose sum is equal to $n$. Let $f(n)$ denote the smallest $n$-Norwegian positive integer. Let $M=3^{2025!}$ and for a non-negative integer $c$ define 
\begin{equation*}
    g(c)=\frac{1}{2025!}\left\lfloor \frac{2025! f(M+c)}{M}\right\rfloor.
\end{equation*}
We can write 
\begin{equation*}
    g(0)+g(4M)+g(1848374)+g(10162574)+g(265710644)+g(44636594)=\frac{p}{q}
\end{equation*}
where $p$ and $q$ are coprime positive integers. What is the remainder when $p+q$ is divided by $99991$?

Budget: 900.00 seconds | Deadline: 1769157314.42



Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,4,23524,20,1,0.721,23.0,4:10
1,7,29386,37,3,0.725,52659.0,5:53
2,5,32323,47,0,0.675,,6:06
3,2,38766,18,1,0.735,79.0,7:03
4,6,37626,28,4,0.711,5.0,7:11
5,8,40119,67,3,0.709,8687.0,7:46
6,1,41004,68,8,0.699,85452.0,8:16
7,3,47194,87,6,0.691,8687.0,8:55


Unnamed: 0,Answer,Votes,Calls
0,8687,2,154
1,85452,1,68
2,52659,1,37
3,5,1,28
4,23,1,20
5,79,1,18



Final Result: 8687 | Votes: 2 | Calls: 154

Problem ID: 86e8e5, spent time: 8:55, problems_remaining = 47
Debug log written to 86e8e5.md

Problem: A $500 \times 500$ square is divided into $k$ rectangles, each having integer side lengths. Given that no two of these rectangles have the same perimeter, the largest possible value of $k$ is $\mathcal{K}$. What is the remainder when $k$ is divided by $10^{5}$?

Budget: 900.00 seconds | Deadline: 1769157849.26



Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,8,19852,2,0,0.949,520,3:17
1,7,22859,8,0,0.894,520,3:51
2,4,24718,8,1,0.898,520,4:18
3,6,26311,5,0,0.903,520,4:26


Unnamed: 0,Answer,Votes,Calls
0,520,4,23



Final Result: 520 | Votes: 4 | Calls: 23

Problem ID: a295e9, spent time: 4:27, problems_remaining = 46
Debug log written to a295e9.md

Problem: Alice and Bob are each holding some integer number of sweets. Alice says to Bob: ``If we each added the number of sweets we're holding to our (positive integer) age, my answer would be double yours. If we took the product, then my answer would be four times yours.'' Bob replies: ``Why don't you give me five of your sweets because then both our sum and product would be equal.'' What is the product of Alice and Bob's ages?

Budget: 900.00 seconds | Deadline: 1769158116.54



Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,5,1553,1,0,0.601,50,0:15
1,1,1977,1,0,0.569,50,0:19
2,8,2000,1,0,0.578,50,0:19
3,2,2149,2,0,0.661,50,0:20


Unnamed: 0,Answer,Votes,Calls
0,50,4,5



Final Result: 50 | Votes: 4 | Calls: 5

Problem ID: 92ba6a, spent time: 0:21, problems_remaining = 45
Debug log written to 92ba6a.md

Problem: Let $f \colon \mathbb{Z}_{\geq 1} \to \mathbb{Z}_{\geq 1}$ be a function such that for all positive integers $m$ and $n$, 
\begin{equation*}
    f(m) + f(n) = f(m + n + mn).
\end{equation*}
Across all functions $f$ such that $f(n) \leq 1000$ for all $n \leq 1000$, how many different values can $f(2024)$ take?

Budget: 900.00 seconds | Deadline: 1769158137.78



Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,7,7581,3,0,0.762,580,1:13
1,3,9195,11,0,0.663,580,1:33
2,8,8949,10,1,0.809,580,1:34
3,2,10988,10,0,0.769,580,1:45


Unnamed: 0,Answer,Votes,Calls
0,580,4,34



Final Result: 580 | Votes: 4 | Calls: 34

Problem ID: 9c1c5f, spent time: 1:47, problems_remaining = 44
Debug log written to 9c1c5f.md

Problem: Define a function $f \colon \mathbb{Z}_{\geq 1} \to \mathbb{Z}_{\geq 1}$ by
\begin{equation*}
    f(n) = \sum_{i = 1}^n \sum_{j = 1}^n j^{1024} \left\lfloor\frac1j + \frac{n-i}{n}\right\rfloor.
\end{equation*}
Let $M=2 \cdot 3 \cdot 5 \cdot 7 \cdot 11 \cdot 13$ and let $N = f{\left(M^{15}\right)} - f{\left(M^{15}-1\right)}$. Let $k$ be the largest non-negative integer such that $2^k$ divides $N$. What is the remainder when $2^k$ is divided by $5^7$?

Budget: 900.00 seconds | Deadline: 1769158244.62



Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,4,4040,3,0,0.563,32951,0:37
1,6,5175,6,0,0.532,32951,0:49
2,1,5740,5,0,0.556,32951,0:54
3,2,6002,3,0,0.606,32951,0:56


Unnamed: 0,Answer,Votes,Calls
0,32951,4,17



Final Result: 32951 | Votes: 4 | Calls: 17

Problem ID: 26de63, spent time: 0:57, problems_remaining = 43
Debug log written to 26de63.md

Problem: Let $ABC$ be a triangle with $AB \neq AC$, circumcircle $\Omega$, and incircle $\omega$. Let the contact points of $\omega$ with $BC$, $CA$, and $AB$ be $D$, $E$, and $F$, respectively. Let the circumcircle of $AFE$ meet $\Omega$ at $K$ and let the reflection of $K$ in $EF$ be $K'$. Let $N$ denote the foot of the perpendicular from $D$ to $EF$. The circle tangent to line $BN$ and passing through $B$ and $K$ intersects $BC$ again at $T \neq B$. 
    
Let sequence $(F_n)_{n \geq 0}$ be defined by $F_0 = 0$, $F_1 = 1$ and for $n \geq 2$, $F_n = F_{n-1} + F_{n-2}$. Call $ABC$ $n$\emph{-tastic} if $BD = F_n$, $CD = F_{n+1}$, and $KNK'B$ is cyclic. Across all $n$-tastic triangles, let $a_n$ denote the maximum possible value of $\frac{CT \cdot NB}{BT \cdot NE}$. Let $\alpha$ denote the smallest real number such that for all sufficiently large $n$,

Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,3,12492,20,3,0.548,57447,2:19
1,7,16301,12,1,0.538,57447,2:47
2,2,20695,19,3,0.496,57447,3:36
3,1,29532,38,5,0.64,99989,5:14
4,5,32560,49,10,0.577,0,5:35
5,8,32293,31,4,0.489,57447,6:04


Unnamed: 0,Answer,Votes,Calls
0,57447,4,82
1,0,1,49
2,99989,1,38



Final Result: 57447 | Votes: 4 | Calls: 82

Problem ID: 641659, spent time: 6:10, problems_remaining = 42
Debug log written to 641659.md

Problem: A tournament is held with $2^{20}$ runners each of which has a different running speed. In each race, two runners compete against each other with the faster runner always winning the race. The competition consists of $20$ rounds with each runner starting with a score of $0$. In each round, the runners are paired in such a way that in each pair, both runners have the same score at the beginning of the round. The winner of each race in the $i^{\text{th}}$ round receives $2^{20-i}$ points and the loser gets no points.

At the end of the tournament, we rank the competitors according to their scores. Let $N$ denote the number of possible orderings of the competitors at the end of the tournament. Let $k$ be the largest positive integer such that $10^k$ divides $N$. What is the remainder when $k$ is divided by $10^{5}$?

Budget: 900.00 seconds | D

Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,8,11476,3,1,0.846,21818,1:54
1,3,11680,15,2,0.754,21818,2:05
2,1,12409,11,2,0.774,21818,2:05
3,7,12884,15,0,0.812,21818,2:09


Unnamed: 0,Answer,Votes,Calls
0,21818,4,44



Final Result: 21818 | Votes: 4 | Calls: 44

Problem ID: 424e18, spent time: 2:13, problems_remaining = 41
Debug log written to 424e18.md

Problem: Let $ABC$ be an acute-angled triangle with integer side lengths and $AB<AC$. Points $D$ and $E$ lie on segments $BC$ and $AC$, respectively, such that $AD=AE=AB$. Line $DE$ intersects $AB$ at $X$. Circles $BXD$ and $CED$ intersect for the second time at $Y \neq D$. Suppose that $Y$ lies on line $AD$. There is a unique such triangle with minimal perimeter. This triangle has side lengths $a=BC$, $b=CA$, and $c=AB$. Find the remainder when $abc$ is divided by $10^{5}$.

Budget: 900.00 seconds | Deadline: 1769158805.43



Unnamed: 0,Attempt,Response Length,Python Calls,Python Errors,Entropy,Answer,Time
0,8,12912,10,2,0.574,336,2:08
1,3,14056,9,0,0.605,336,2:20
2,1,15055,10,0,0.616,2688,2:30
3,5,16412,16,0,0.622,336,2:44
4,6,18599,14,1,0.586,336,3:01


Unnamed: 0,Answer,Votes,Calls
0,336,4,49
1,2688,1,10



Final Result: 336 | Votes: 4 | Calls: 49

Problem ID: 0e644e, spent time: 3:02, problems_remaining = 40
Debug log written to 0e644e.md


ValueError: too many values to unpack (expected 2)

In [14]:
# Stand-alone accuracy report over the answers collected in `predict_answers`
# (the same report that `test()` prints), compared against reference.csv.
print(predict_answers)
test_csv = '/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv'
df = pd.read_csv(test_csv)
real_answers = dict(zip(df["id"], df["answer"])) if "answer" in df.columns else {}
correct_count = 0
total_count = 0
# Check accuracy if ground truth available.
# The loop variable is `problem_id`, not `id`: at module level a loop variable
# named `id` would shadow the builtin for the rest of the kernel session.
for problem_id, predict_answer in predict_answers.items():
    if problem_id in real_answers:
        total_count += 1
        real_answer = real_answers[problem_id]
        is_correct = (predict_answer == real_answer)
        if is_correct:
            correct_count += 1
        status = "✅" if is_correct else "❌"
        print(f"Problem ID: {problem_id} -- Predict Answer: {predict_answer} | Ground Truth: {real_answer} | {status}")
    else:
        print(f"Problem ID: {problem_id} -- Predict Answer: {predict_answer}")
# Guard the percentage: with no matching ground truth the division would raise
# ZeroDivisionError.
if total_count:
    print(f"📊 Running Accuracy: {correct_count}/{total_count} ({100*correct_count/total_count:.1f}%)")
else:
    print("📊 Running Accuracy: 0/0 (no ground-truth answers available)")

{'42d360': 32193, 'dd7f5e': 44, '86e8e5': 8687, 'a295e9': 520, '92ba6a': 50, '9c1c5f': 580, '26de63': 32951, '641659': 57447, '424e18': 21818, '0e644e': 336}
Problem ID: 42d360 -- Predict Answer: 32193 | Ground Truth: 32193 | ✅
Problem ID: dd7f5e -- Predict Answer: 44 | Ground Truth: 160 | ❌
Problem ID: 86e8e5 -- Predict Answer: 8687 | Ground Truth: 8687 | ✅
Problem ID: a295e9 -- Predict Answer: 520 | Ground Truth: 520 | ✅
Problem ID: 92ba6a -- Predict Answer: 50 | Ground Truth: 50 | ✅
Problem ID: 9c1c5f -- Predict Answer: 580 | Ground Truth: 580 | ✅
Problem ID: 26de63 -- Predict Answer: 32951 | Ground Truth: 32951 | ✅
Problem ID: 641659 -- Predict Answer: 57447 | Ground Truth: 57447 | ✅
Problem ID: 424e18 -- Predict Answer: 21818 | Ground Truth: 21818 | ✅
Problem ID: 0e644e -- Predict Answer: 336 | Ground Truth: 336 | ✅
📊 Running Accuracy: 9/10 (90.0%)
